Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,21 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
}
}

if (sequenceLength > maxSequenceLength) {
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
}

// Note: position_ids value validation is handled by shader-side bounds checks (defense-in-depth).
// We cannot validate position_ids values here because the tensor is GPU-resident — its data field
// is a GPU buffer ID, not a WASM heap pointer, so getBigInt64Array() would read garbage.

if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) {
throw new Error(
`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${
cosCache.dims[1]
}`,
);
}

if (sequenceLength > maxSequenceLength) {
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
}
};

export const createRotaryEmbeddingProgramInfo = (
Expand Down
73 changes: 54 additions & 19 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,28 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
" if (global_idx >= size) { return; }\n"
" if (bsnh[3] < half_rotary_emb_dim) {\n"
<< " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n"
<< " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n"
<< " let raw_pos = " << position_ids.GetByOffset("position_ids_idx") << ";\n"
Comment thread
tianleiwu marked this conversation as resolved.
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n"
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
" let max_position = uniforms.cos_cache_shape[0];\n"
// Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
// After u32 conversion + offset, check >= max_position catches too-large values.
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
" if (raw_pos < 0) {\n"
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
" } else {\n"
" let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n"
" if (position_id >= max_position) {\n"
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
" } else {\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
" }\n"
" }\n"
<< " } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
<< " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
Expand Down Expand Up @@ -74,24 +89,39 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
<< " let seqlen = u32(seqlen_i);\n"
<< " let total_seqlen = seqlen + 1u;\n"
<< " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n"
// position_id is derived from past_seqlen + sequence_idx (always non-negative).
<< " let position_id = past_seqlen + sequence_idx;\n"
<< " let cos_v = " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let sin_v = " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n"
<< " " << q_output.SetByOffset("qi", "q_re") << "\n"
<< " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n"
<< " " << q_output.SetByOffset("qj", "q_im") << "\n"
// Bounds check: position_id must be within cos/sin cache range.
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
" let max_position = uniforms.cos_cache_shape[0];\n"
" if (position_id >= max_position) {\n"
<< " " << q_output.SetByOffset("qi", q_input.GetByOffset("qi")) << "\n"
<< " " << q_output.SetByOffset("qj", q_input.GetByOffset("qj")) << "\n"
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " " << k_output.SetByOffset("ki", k_input.GetByOffset("ki")) << "\n"
<< " " << k_output.SetByOffset("kj", k_input.GetByOffset("kj")) << "\n"
" }\n"
" } else {\n"
<< " let cos_v = " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let sin_v = " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n"
<< " " << q_output.SetByOffset("qi", "q_re") << "\n"
<< " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n"
<< " " << q_output.SetByOffset("qj", "q_im") << "\n"
// Conditionally process Key (only for heads that exist in K domain)
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n"
<< " " << k_output.SetByOffset("ki", "k_re") << "\n"
<< " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n"
<< " " << k_output.SetByOffset("kj", "k_im") << "\n"
<< " }\n"
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n"
<< " " << k_output.SetByOffset("ki", "k_re") << "\n"
<< " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n"
<< " " << k_output.SetByOffset("kj", "k_im") << "\n"
" }\n"
" }\n"
<< " } else {\n"
<< " let qk = dot(bsnh, uniforms.q_input_output_stride) + half_rotary_dim;\n"
<< " " << q_output.SetByOffset("qk", q_input.GetByOffset("qk")) << "\n"
Expand Down Expand Up @@ -127,6 +157,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con
const auto half_rotary_embedding_dim = onnxruntime::narrow<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_;

Comment thread
titaiwangms marked this conversation as resolved.
// position_ids bounds validation is handled by shader-side defense-in-depth checks
// (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible
// because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is
// incompatible with AddInputs).

// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
// to unfold the global index in shader.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const {

if (position_ids != nullptr) {
// position_ids provided: cos/sin cache is 2D (max_pos, D/2)
// position_ids bounds validation is handled by shader-side defense-in-depth checks
// (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible
// because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is
// incompatible with AddInputs).
// Note: ONNX RotaryEmbedding has no base-offset mode (format 0) — position_ids is always
// a 2D tensor (batch_size, sequence_length) when provided.

contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
program
.CacheHint(interleaved_)
Expand Down
122 changes: 120 additions & 2 deletions onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -937,10 +937,11 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_CUDA_Passthroug
test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_data);
// position_id = 2048 exceeds max_sequence_length = 8 — CUDA should pass through input unchanged.
test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {2048});
// Non-trivial cache values ensure pass-through (output=input) differs from valid rotary output.
test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2},
std::vector<float>(max_sequence_length * head_size / 2, 1.0f));
std::vector<float>(max_sequence_length * head_size / 2, 0.5f));
test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2},
std::vector<float>(max_sequence_length * head_size / 2, 0.0f));
std::vector<float>(max_sequence_length * head_size / 2, 0.866f));

// Output should equal input when position_id is OOB (pass-through).
test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, input_data);
Expand Down Expand Up @@ -1054,5 +1055,122 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank4MalformedCacheWidth
{}, nullptr, &execution_providers);
}

// WebGPU, per-token position_ids of shape (batch, sequence): every id lies beyond the
// cos/sin cache rows, so the shader-side bounds check is expected to copy the input
// to the output untouched instead of applying the rotation.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_WebGPU_Passthrough) {
  if (DefaultWebGpuExecutionProvider() == nullptr) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch = 1;
  const int seq_len = 2;
  const int heads = 2;
  const int head_dim = 4;
  const int max_positions = 8;
  const int hidden = heads * head_dim;
  const int cache_elements = max_positions * head_dim / 2;

  // Fill the input with 1, 2, 3, ... so every element is distinct and non-zero.
  std::vector<float> input_data(batch * seq_len * hidden);
  float next_value = 1.0f;
  for (float& element : input_data) {
    element = next_value;
    next_value += 1.0f;
  }

  // cos = 0.5 / sin = 0.866 (roughly a 60-degree rotation): a rotation that was actually
  // applied would change the values, so "output == input" can only come from pass-through.
  const std::vector<float> cos_values(cache_elements, 0.5f);
  const std::vector<float> sin_values(cache_elements, 0.866f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch, seq_len, hidden}, input_data);
  // Both ids (999) are past the last valid cache row (max_positions - 1 = 7).
  test.AddInput<int64_t>("position_ids", {batch, seq_len}, {999, 999});
  test.AddInput<float>("cos_cache", {max_positions, head_dim / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_positions, head_dim / 2}, sin_values);

  // Expect a bit-exact copy of the input.
  test.AddOutput<float>("output", {batch, seq_len, hidden}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  std::vector<std::unique_ptr<IExecutionProvider>> webgpu_only;
  webgpu_only.emplace_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &webgpu_only);
}

// WebGPU, format-0 position_ids (a single base offset of shape {1}): the base offset
// already sits at the end of the cos/sin cache, so every effective position is out of
// range and the shader-side bounds check is expected to pass the input through unchanged.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Format0_OOB_WebGPU_Passthrough) {
  if (DefaultWebGpuExecutionProvider() == nullptr) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch = 1;
  const int seq_len = 2;
  const int heads = 2;
  const int head_dim = 4;
  const int max_positions = 8;
  const int hidden = heads * head_dim;
  const int cache_elements = max_positions * head_dim / 2;

  // Fill the input with 1, 2, 3, ... so every element is distinct and non-zero.
  std::vector<float> input_data(batch * seq_len * hidden);
  float next_value = 1.0f;
  for (float& element : input_data) {
    element = next_value;
    next_value += 1.0f;
  }

  // cos = 0.5 / sin = 0.866 (roughly a 60-degree rotation): a rotation that was actually
  // applied would change the values, so "output == input" can only come from pass-through.
  const std::vector<float> cos_values(cache_elements, 0.5f);
  const std::vector<float> sin_values(cache_elements, 0.866f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch, seq_len, hidden}, input_data);
  // Base offset 8 with sequence length 2 gives effective positions [8, 9]; valid rows are 0..7.
  test.AddInput<int64_t>("position_ids", {1}, {8});
  test.AddInput<float>("cos_cache", {max_positions, head_dim / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_positions, head_dim / 2}, sin_values);

  // Expect a bit-exact copy of the input.
  test.AddOutput<float>("output", {batch, seq_len, hidden}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  std::vector<std::unique_ptr<IExecutionProvider>> webgpu_only;
  webgpu_only.emplace_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &webgpu_only);
}

// WebGPU, negative position_id: the shader-side `raw_pos < 0` guard is expected to copy
// the input to the output untouched instead of indexing the cos/sin cache.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative_WebGPU_Passthrough) {
  if (DefaultWebGpuExecutionProvider() == nullptr) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch = 1;
  const int seq_len = 1;
  const int heads = 2;
  const int head_dim = 4;
  const int max_positions = 8;
  const int hidden = heads * head_dim;
  const int cache_elements = max_positions * head_dim / 2;

  // Single token: fill its hidden vector with 1, 2, 3, ... (distinct, non-zero values).
  std::vector<float> input_data(hidden);
  float next_value = 1.0f;
  for (float& element : input_data) {
    element = next_value;
    next_value += 1.0f;
  }

  // cos = 0.5 / sin = 0.866 (roughly a 60-degree rotation): a rotation that was actually
  // applied would change the values, so "output == input" can only come from pass-through.
  const std::vector<float> cos_values(cache_elements, 0.5f);
  const std::vector<float> sin_values(cache_elements, 0.866f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch, seq_len, hidden}, input_data);
  // -5 can never be a valid cache row.
  test.AddInput<int64_t>("position_ids", {batch, seq_len}, {-5});
  test.AddInput<float>("cos_cache", {max_positions, head_dim / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_positions, head_dim / 2}, sin_values);

  // Expect a bit-exact copy of the input.
  test.AddOutput<float>("output", {batch, seq_len, hidden}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  std::vector<std::unique_ptr<IExecutionProvider>> webgpu_only;
  webgpu_only.emplace_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &webgpu_only);
}

} // namespace test
} // namespace onnxruntime
Loading
Loading