diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts index fe2567e71d49a..9bbad9839d616 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts @@ -62,6 +62,14 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi } } + if (sequenceLength > maxSequenceLength) { + throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported'); + } + + // Note: position_ids value validation is handled by shader-side bounds checks (defense-in-depth). + // We cannot validate position_ids values here because the tensor is GPU-resident — its data field + // is a GPU buffer ID, not a WASM heap pointer, so getBigInt64Array() would read garbage. + if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) { throw new Error( `Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${ @@ -69,10 +77,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi }`, ); } - - if (sequenceLength > maxSequenceLength) { - throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported'); - } }; export const createRotaryEmbeddingProgramInfo = ( diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 9f81e490971cd..69d2db391ce3c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -35,13 +35,28 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { " if (global_idx >= size) { return; }\n" " if (bsnh[3] < half_rotary_emb_dim) {\n" << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" - << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n" + << " let raw_pos = " << position_ids.GetByOffset("position_ids_idx") << ";\n" << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" - << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" - << " " << output.SetByOffset("i", "re") << "\n" - << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" - << " " << output.SetByOffset("j", "im") << "\n" + " let max_position = uniforms.cos_cache_shape[0];\n" + // Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64). + // After u32 conversion + offset, check >= max_position catches too-large values. + // On OOB, pass through input unchanged (same as CUDA kernel behavior). + " if (raw_pos < 0) {\n" + << " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n" + << " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n" + " } else {\n" + " let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n" + " if (position_id >= max_position) {\n" + << " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n" + << " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n" + " } else {\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + " }\n" + " }\n" << " } else { \n" " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" @@ -74,24 +89,39 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c << " let seqlen = u32(seqlen_i);\n" << " let total_seqlen = seqlen + 1u;\n" << " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n" + // position_id is derived from past_seqlen + sequence_idx (always non-negative). << " let position_id = past_seqlen + sequence_idx;\n" - << " let cos_v = " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" - << " let sin_v = " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" << " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" - << " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n" - << " " << q_output.SetByOffset("qi", "q_re") << "\n" - << " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n" - << " " << q_output.SetByOffset("qj", "q_im") << "\n" + // Bounds check: position_id must be within cos/sin cache range. + // On OOB, pass through input unchanged (same as CUDA kernel behavior). + " let max_position = uniforms.cos_cache_shape[0];\n" + " if (position_id >= max_position) {\n" + << " " << q_output.SetByOffset("qi", q_input.GetByOffset("qi")) << "\n" + << " " << q_output.SetByOffset("qj", q_input.GetByOffset("qj")) << "\n" + << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" + << " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" + << " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" + << " " << k_output.SetByOffset("ki", k_input.GetByOffset("ki")) << "\n" + << " " << k_output.SetByOffset("kj", k_input.GetByOffset("kj")) << "\n" + " }\n" + " } else {\n" + << " let cos_v = " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " let sin_v = " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n" + << " " << q_output.SetByOffset("qi", "q_re") << "\n" + << " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n" + << " " << q_output.SetByOffset("qj", "q_im") << "\n" // Conditionally process Key (only for heads that exist in K domain) - << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" - << " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" - << " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" - << " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n" - << " " << k_output.SetByOffset("ki", "k_re") << "\n" - << " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n" - << " " << k_output.SetByOffset("kj", "k_im") << "\n" - << " }\n" + << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" + << " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" + << " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" + << " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n" + << " " << k_output.SetByOffset("ki", "k_re") << "\n" + << " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n" + << " " << k_output.SetByOffset("kj", "k_im") << "\n" + " }\n" + " }\n" << " } else {\n" << " let qk = dot(bsnh, uniforms.q_input_output_stride) + half_rotary_dim;\n" << " " << q_output.SetByOffset("qk", q_input.GetByOffset("qk")) << "\n" @@ -127,6 +157,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con const auto half_rotary_embedding_dim = onnxruntime::narrow(cos_cache->Shape()[1]); const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + // position_ids bounds validation is handled by shader-side defense-in-depth checks + // (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible + // because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is + // incompatible with AddInputs). + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] // to unfold the global index in shader. diff --git a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc index ee46c76f1ea54..234b1d54e69c5 100644 --- a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc +++ b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc @@ -83,6 +83,13 @@ Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const { if (position_ids != nullptr) { // position_ids provided: cos/sin cache is 2D (max_pos, D/2) + // position_ids bounds validation is handled by shader-side defense-in-depth checks + // (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible + // because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is + // incompatible with AddInputs). + // Note: ONNX RotaryEmbedding has no base-offset mode (format 0) — position_ids is always + // a 2D tensor (batch_size, sequence_length) when provided. + contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; program .CacheHint(interleaved_) diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 1fc410c37da14..880c10137f3fe 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -937,10 +937,11 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_CUDA_Passthroug test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); // position_id = 2048 exceeds max_sequence_length = 8 — CUDA should pass through input unchanged. test.AddInput("position_ids", {batch_size, sequence_length}, {2048}); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, - std::vector(max_sequence_length * head_size / 2, 1.0f)); + std::vector(max_sequence_length * head_size / 2, 0.5f)); test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, - std::vector(max_sequence_length * head_size / 2, 0.0f)); + std::vector(max_sequence_length * head_size / 2, 0.866f)); // Output should equal input when position_id is OOB (pass-through). test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); @@ -1054,5 +1055,122 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank4MalformedCacheWidth {}, nullptr, &execution_providers); } +// Test that OOB position_ids on WebGPU (format 1) pass through input unchanged (shader-side defense). +TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 1; + int sequence_length = 2; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", static_cast(0)); + + std::vector input_data(batch_size * sequence_length * hidden_size); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Both position_ids exceed max_sequence_length = 8 — shader passes through input unchanged. + test.AddInput("position_ids", {batch_size, sequence_length}, {999, 999}); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + + // Output should equal input when position_id is OOB (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test that format-0 OOB position_ids base offset passes through on WebGPU (shader-side defense). +TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Format0_OOB_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 1; + int sequence_length = 2; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", static_cast(0)); + + std::vector input_data(batch_size * sequence_length * hidden_size); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Format 0: base offset 8, effective positions = [8, 9] — both OOB for max_sequence_length = 8. + test.AddInput("position_ids", {1}, {8}); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + + // Output should equal input when all positions are OOB (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test that negative position_ids pass through on WebGPU (shader-side defense catches raw_pos < 0). +TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 1; + int sequence_length = 1; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", static_cast(0)); + + std::vector input_data(hidden_size); + for (int i = 0; i < hidden_size; ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Negative position_id — shader checks raw_pos < 0 and passes through. + test.AddInput("position_ids", {batch_size, sequence_length}, {-5}); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + + // Output should equal input when position_id is negative (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc index 6a3b0d8160d53..2f51b8a7a5690 100644 --- a/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc @@ -1208,10 +1208,11 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_OOB_CUDA_Passthrough) { } test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, - std::vector(max_sequence_length * head_size / 2, 1.0f)); + std::vector(max_sequence_length * head_size / 2, 0.5f)); test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, - std::vector(max_sequence_length * head_size / 2, 0.0f)); + std::vector(max_sequence_length * head_size / 2, 0.866f)); // position_id = 2048 exceeds max_sequence_length = 8 — CUDA should pass through input unchanged. test.AddInput("position_ids", {batch_size, sequence_length}, {2048}); @@ -1291,5 +1292,125 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_RejectsRank3HiddenSizeNotDivisibleByNu "hidden_size=5 must be divisible by num_heads=2 for rank-3 input", {}, nullptr, &execution_providers); } +// Test that OOB position_ids on WebGPU pass through input unchanged (shader-side defense). +TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_OOB_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 1; + int sequence_length = 1; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 23, onnxruntime::kOnnxDomain); + test.AddAttribute("interleaved", static_cast(0)); + test.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector input_data(hidden_size); + for (int i = 0; i < hidden_size; ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + // position_id = 2048 exceeds max_sequence_length = 8 — shader passes through input unchanged. + test.AddInput("position_ids", {batch_size, sequence_length}, {2048}); + + // Output should equal input when position_id is OOB (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test that negative position_ids pass through on WebGPU (shader-side defense catches raw_pos < 0). +TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_Negative_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 1; + int sequence_length = 1; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 23, onnxruntime::kOnnxDomain); + test.AddAttribute("interleaved", static_cast(0)); + test.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector input_data(hidden_size); + for (int i = 0; i < hidden_size; ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + // Negative position_id — shader checks raw_pos < 0 and passes through. + test.AddInput("position_ids", {batch_size, sequence_length}, {-1}); + + // Output should equal input when position_id is negative (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test that OOB position_ids in a batch pass through on WebGPU (shader-side defense). +TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_OOB_InBatch_WebGPU_Passthrough) { + if (nullptr == DefaultWebGpuExecutionProvider().get()) { + GTEST_SKIP() << "WebGPU execution provider is not available."; + } + + int batch_size = 2; + int sequence_length = 2; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int hidden_size = num_heads * head_size; + + OpTester test("RotaryEmbedding", 23, onnxruntime::kOnnxDomain); + test.AddAttribute("interleaved", static_cast(0)); + test.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector input_data(batch_size * sequence_length * hidden_size); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = static_cast(i + 1); + } + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, input_data); + // Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output. + test.AddInput("cos_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.5f)); + test.AddInput("sin_cache", {max_sequence_length, head_size / 2}, + std::vector(max_sequence_length * head_size / 2, 0.866f)); + // All OOB position_ids — shader passes through input unchanged. + test.AddInput("position_ids", {batch_size, sequence_length}, {100, 200, 300, 400}); + + // Output should equal input when all position_ids are OOB (pass-through). + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, input_data); + test.SetOutputAbsErr("output", 0.0f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime