Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into hualxie/config_op
Browse files Browse the repository at this point in the history
  • Loading branch information
hualxie committed Feb 11, 2025
2 parents 839c79b + e666503 commit 034b06d
Show file tree
Hide file tree
Showing 115 changed files with 4,613 additions and 1,471 deletions.
7 changes: 7 additions & 0 deletions .config/1espt/PipelineAutobaseliningConfig.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ pipelines:
lastModifiedDate: 2024-10-25
armory:
lastModifiedDate: 2024-10-25
binary:
credscan:
lastModifiedDate: 2025-02-06
binskim:
lastModifiedDate: 2025-02-06
spotbugs:
lastModifiedDate: 2025-02-06
usedNonDefaultBranch: true
1299:
retail:
Expand Down
14 changes: 14 additions & 0 deletions .config/guardian/.gdnbaselines
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@
"createdDate": "2024-11-13 11:20:17Z",
"expirationDate": "2025-05-02 11:55:15Z",
"justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 11:55:15Z"
},
"6f6606e50e82b2d3c823c435151f4b69c1fbde92f274753b793d948856cfc462": {
"signature": "6f6606e50e82b2d3c823c435151f4b69c1fbde92f274753b793d948856cfc462",
"alternativeSignatures": [],
"target": "ScanTelemetry_20250206154816289.json",
"line": 1,
"memberOf": [
"default"
],
"tool": "credscan",
"ruleId": "CSCAN-AZURE0130",
"createdDate": "2025-02-06 15:53:46Z",
"expirationDate": "2025-07-26 16:26:55Z",
"justification": "This error is baselined with an expiration date of 180 days from 2025-02-06 16:26:55Z"
}
}
}
9 changes: 9 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C+
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)
option(onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP "Adding frame present for PIX to capture a frame" OFF)
# The following 2 options are only for Windows
option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF)
option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON)
Expand Down Expand Up @@ -1038,6 +1039,14 @@ if (onnxruntime_USE_WEBGPU)
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1)
endif()
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12 OR NOT WIN32)
message(
FATAL_ERROR
"Option onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP can only be set on windows with onnxruntime_ENABLE_DAWN_BACKEND_D3D12 is enabled.")
endif()
add_compile_definitions(ENABLE_PIX_FOR_WEBGPU_EP)
endif()
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
Expand Down
10 changes: 5 additions & 5 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zi
googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349
#xnnpack 2024.09.04
googlexnnpack;https://github.com/google/XNNPACK/archive/fe98e0b93565382648129271381c14d6205255e3.zip;14f61dcf17cec2cde34ba2dcf61d6f24bf6059f3
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.17.0.zip;13a60ac5217c104139ce0fd024f48628e7bcf5bc
# Use the latest commit of 10.7-GA
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/9c69a24bc2e20c8a511a4e6b06fd49639ec5300a.zip;ff1fe9af78eb129b4a4cdcb7450b7390b4436dd3
# Use the latest commit of 10.8-GA
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/118ed0aea197fa9a7d3ea66180a1d5ddb9deecc3.zip;b78aed3728ad4daf6dc47ea10c1d243dee1d95b1
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a
protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874
Expand All @@ -47,14 +47,14 @@ protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/downlo
protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef
psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013
pthreadpool;https://github.com/google/pthreadpool/archive/4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0.zip;bd4ea65c8292801e9555b527a0ecbb2e0092c917
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.zip;9255d5c8568debcc329dd42ed8f410ee139ac7b1
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9
pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6.zip;85bf8a60dae026b99b6ccd78606c85ed83bfb2cd
re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/f3f6caa6e8adb420e005ec41c6fefc8d75affb6e.zip;cec2e164f1a00e7d80fd94df65e4e8d2daead70d
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
Expand Down
27 changes: 24 additions & 3 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -600,16 +600,15 @@ endif()

if(onnxruntime_ENABLE_DLPACK)
message(STATUS "dlpack is enabled.")

onnxruntime_fetchcontent_declare(
dlpack
URL ${DEP_URL_dlpack}
URL_HASH SHA1=${DEP_SHA1_dlpack}
EXCLUDE_FROM_ALL
FIND_PACKAGE_ARGS NAMES dlpack
)
# We can't use onnxruntime_fetchcontent_makeavailable since some part of the the dlpack code is Linux only.
# For example, dlpackcpp.h uses posix_memalign.
FetchContent_Populate(dlpack)
onnxruntime_fetchcontent_makeavailable(dlpack)
endif()

if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxruntime_BUILD_UNIT_TESTS))
Expand Down Expand Up @@ -686,6 +685,24 @@ if (onnxruntime_USE_WEBGPU)
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
endif()

if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE)
else()
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
endif()

# disable things we don't use
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
Expand Down Expand Up @@ -742,6 +759,10 @@ if (onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()
endif()

if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES glfw webgpu_glfw)
endif()
endif()

if(onnxruntime_USE_COREML)
Expand Down
7 changes: 2 additions & 5 deletions cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,15 @@ endif()
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_framework PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OR onnxruntime_ENABLE_TRITON)
onnxruntime_add_include_to_target(onnxruntime_framework Python::Module)
target_include_directories(onnxruntime_framework PRIVATE ${dlpack_SOURCE_DIR}/include)
onnxruntime_add_include_to_target(onnxruntime_framework Python::Module dlpack::dlpack)
endif()
endif()
if (onnxruntime_USE_MPI)
target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS})
endif()

if (onnxruntime_ENABLE_ATEN)
# DLPack is a header-only dependency
set(DLPACK_INCLUDE_DIR ${dlpack_SOURCE_DIR}/include)
target_include_directories(onnxruntime_framework PRIVATE ${DLPACK_INCLUDE_DIR})
onnxruntime_add_include_to_target(onnxruntime_framework dlpack::dlpack)
endif()
onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11 nlohmann_json::nlohmann_json)

Expand Down
8 changes: 8 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/rotary_embedding.h
${MLAS_SRC_DIR}/rotary_embedding.cpp
${MLAS_SRC_DIR}/softmax.h
)

target_sources(onnxruntime_mlas PRIVATE
Expand Down Expand Up @@ -97,6 +98,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon.h
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -377,6 +381,8 @@ else()
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon.h
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
Expand All @@ -398,6 +404,7 @@ else()
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
Expand All @@ -411,6 +418,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
3 changes: 1 addition & 2 deletions cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ endif()
if (onnxruntime_ENABLE_DLPACK)
target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_DLPACK)
# DLPack is a header-only dependency
set(DLPACK_INCLUDE_DIR ${dlpack_SOURCE_DIR}/include)
target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR})
onnxruntime_add_include_to_target(onnxruntime_providers dlpack::dlpack)
endif()

if (onnxruntime_ENABLE_TRAINING)
Expand Down
8 changes: 4 additions & 4 deletions cmake/onnxruntime_providers_openvino.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

# Header paths
find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
if(OpenVINO_VERSION VERSION_LESS 2024.4)
message(FATAL_ERROR "OpenVINO 2024.4 and newer are supported. Please, use latest OpenVINO release")
if(OpenVINO_VERSION VERSION_LESS 2024.5)
message(FATAL_ERROR "OpenVINO 2024.5 and newer are supported. Please, use latest OpenVINO release")
endif()

if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4)
Expand All @@ -30,7 +30,7 @@
endif()

list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES})
if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}))
if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}) AND onnxruntime_USE_OPENVINO_GPU)
add_definitions(-DIO_BUFFER_ENABLED=1)
list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS})
endif()
Expand Down Expand Up @@ -86,4 +86,4 @@
set_target_properties(onnxruntime_providers_openvino PROPERTIES
MAP_IMPORTED_CONFIG_RELEASE RelWithDebInfo
MAP_IMPORTED_CONFIG_DEBUG RelWithDebInfo
)
)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ if (onnxruntime_ENABLE_ATEN)
endif()

if (onnxruntime_ENABLE_DLPACK)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${dlpack_SOURCE_DIR}/include)
target_link_libraries(onnxruntime_pybind11_state PRIVATE dlpack::dlpack)
endif()

if (onnxruntime_ENABLE_TRAINING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";

// Specify the file path for the Onnx model which has EP context.
// Default to original_file_name_ctx.onnx if not specified
// Folder is not a valid option
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";

// Flag to specify whether to dump the EP context into the Onnx model.
Expand Down
36 changes: 24 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ const validateInputs = (
}
};

const getSafeIntegerDivision = (a: string, b: string, c: string, dType: string): string => `
// The whole part and the fractional part are calculated separately due to inaccuracy of floating
// point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
// offset-by-one error later in floor().
let big = (${a}) * (${b});
let whole = ${dType}(big / (${c}));
let fract = ${dType}(big % (${c})) / ${dType}(${c});
return whole + fract;
`;

const getOriginalCoordinateFromResizedCoordinate = (
coordinateTransferMode: CoordinateTransformMode,
dType: string,
Expand All @@ -166,7 +176,13 @@ const getOriginalCoordinateFromResizedCoordinate = (
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
return `return ${dType}(xResized) / ${dType}(xScale);`;
return `
if (xScale < 1.0 || floor(xScale) != xScale) {
return ${dType}(xResized) / ${dType}(xScale);
} else {
${getSafeIntegerDivision('xResized', 'lengthOriginal', 'lengthResized', dType)}
}
`;
case 'pytorch_half_pixel':
return `if (lengthResized > 1) {
return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5;
Expand All @@ -179,13 +195,7 @@ const getOriginalCoordinateFromResizedCoordinate = (
return `if (lengthResized == 1) {
return 0.0;
} else {
// The whole part and the fractional part are calculated separately due to inaccuracy of floating
// point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
// offset-by-one error later in floor().
let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
let fract =
${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
return whole + fract;
${getSafeIntegerDivision('xResized', 'lengthOriginal - 1', 'lengthResized - 1', dType)}
}`;
case 'tf_crop_and_resize':
return `if (lengthResized > 1) {
Expand Down Expand Up @@ -375,7 +385,7 @@ const calculateInputIndicesFromOutputIndices = (
input_index = u32(original_idx);
}
}
${input.indicesSet('input_indices', 'i', ' input_index')}
${input.indicesSet('input_indices', 'i', 'input_index')}
}
return input_indices;
}`;
Expand Down Expand Up @@ -758,9 +768,11 @@ const createResizeProgramInfo = (
return {
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
sizes.length > 0 ? sizes : ''
}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
hint: `${attributes.cacheKey}|${opsetVersion}|${
scales.length > 0 ? (attributes.mode === 'cubic' ? scales : scales.length) : ''
}|${sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${
attributes.mode === 'nearest' ? inputShape.length : inputShape
}`,
inputDependencies: ['rank'],
},
getShaderSource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
temperature = 1.0f;
}
}

// The following parameter is read from environment variable for testing purpose.
use_fast_topk = ParseEnvironmentVariableWithDefault<bool>(kBeamSearchUseFastTopK, true);
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <gsl/gsl>
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
#include "contrib_ops/cpu/utils/debug_macros.h"
#include "contrib_ops/cpu/utils/console_dumper.h"

namespace onnxruntime {

Expand Down Expand Up @@ -199,8 +199,14 @@ struct IGenerationParameters {
int extra_decoding_ids_input_id = -1;
int cross_qk_output_id = -1;
int no_speech_probs_output_id = -1;

// Parameter for testing slow topk path. It can be updated by the below environment variable.
bool use_fast_topk = true;
};

// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled).
constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK";

} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
Loading

0 comments on commit 034b06d

Please sign in to comment.