Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,19 @@ class TensorViewImpl implements TensorView {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0
? new BigInt64Array()
: new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
if (elementCount === 0) {
return new BigInt64Array();
}
// BigInt64Array requires the byte offset to be a multiple of 8. WASM allocators may return
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
// offsets that are not 8-byte aligned, so fall back to copying bytes into an aligned buffer.
// Note: when unaligned, the returned array is a detached copy rather than a view (writes to it won't propagate to the WASM heap).
if (this.data % 8 === 0) {
return new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
}
const byteLength = elementCount * 8;
const alignedBuffer = new ArrayBuffer(byteLength);
new Uint8Array(alignedBuffer).set(new Uint8Array(this.module.HEAP8.buffer, this.data, byteLength));
return new BigInt64Array(alignedBuffer);
}

getInt32Array(): Int32Array {
Expand Down
12 changes: 8 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,21 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
}
}

if (sequenceLength > maxSequenceLength) {
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
}

// Note: position_ids value validation is handled by shader-side bounds checks (defense-in-depth).
// We cannot validate position_ids values here because the tensor is GPU-resident — its data field
// is a GPU buffer ID, not a WASM heap pointer, so getBigInt64Array() would read garbage.

if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) {
throw new Error(
`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${
cosCache.dims[1]
}`,
);
}

if (sequenceLength > maxSequenceLength) {
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
}
};

export const createRotaryEmbeddingProgramInfo = (
Expand Down
73 changes: 54 additions & 19 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,28 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
" if (global_idx >= size) { return; }\n"
" if (bsnh[3] < half_rotary_emb_dim) {\n"
<< " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n"
<< " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n"
<< " let raw_pos = " << position_ids.GetByOffset("position_ids_idx") << ";\n"
Comment thread
tianleiwu marked this conversation as resolved.
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n"
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
" let max_position = uniforms.cos_cache_shape[0];\n"
// Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
// After u32 conversion + offset, check >= max_position catches too-large values.
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
" if (raw_pos < 0) {\n"
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
" } else {\n"
" let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n"
" if (position_id >= max_position) {\n"
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
" } else {\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
" }\n"
" }\n"
<< " } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
<< " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
Expand Down Expand Up @@ -74,24 +89,39 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
<< " let seqlen = u32(seqlen_i);\n"
<< " let total_seqlen = seqlen + 1u;\n"
<< " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n"
// position_id is derived from past_seqlen + sequence_idx (always non-negative).
<< " let position_id = past_seqlen + sequence_idx;\n"
<< " let cos_v = " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let sin_v = " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n"
<< " " << q_output.SetByOffset("qi", "q_re") << "\n"
<< " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n"
<< " " << q_output.SetByOffset("qj", "q_im") << "\n"
// Bounds check: position_id must be within cos/sin cache range.
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
" let max_position = uniforms.cos_cache_shape[0];\n"
" if (position_id >= max_position) {\n"
<< " " << q_output.SetByOffset("qi", q_input.GetByOffset("qi")) << "\n"
<< " " << q_output.SetByOffset("qj", q_input.GetByOffset("qj")) << "\n"
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " " << k_output.SetByOffset("ki", k_input.GetByOffset("ki")) << "\n"
<< " " << k_output.SetByOffset("kj", k_input.GetByOffset("kj")) << "\n"
" }\n"
" } else {\n"
<< " let cos_v = " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let sin_v = " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n"
<< " " << q_output.SetByOffset("qi", "q_re") << "\n"
<< " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n"
<< " " << q_output.SetByOffset("qj", "q_im") << "\n"
// Conditionally process Key (only for heads that exist in K domain)
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n"
<< " " << k_output.SetByOffset("ki", "k_re") << "\n"
<< " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n"
<< " " << k_output.SetByOffset("kj", "k_im") << "\n"
<< " }\n"
<< " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
<< " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n"
<< " " << k_output.SetByOffset("ki", "k_re") << "\n"
<< " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n"
<< " " << k_output.SetByOffset("kj", "k_im") << "\n"
" }\n"
" }\n"
<< " } else {\n"
<< " let qk = dot(bsnh, uniforms.q_input_output_stride) + half_rotary_dim;\n"
<< " " << q_output.SetByOffset("qk", q_input.GetByOffset("qk")) << "\n"
Expand Down Expand Up @@ -127,6 +157,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con
const auto half_rotary_embedding_dim = onnxruntime::narrow<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_;

Comment thread
titaiwangms marked this conversation as resolved.
// position_ids bounds validation is handled by shader-side defense-in-depth checks
// (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible
// because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is
// incompatible with AddInputs).

// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
// to unfold the global index in shader.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const {

if (position_ids != nullptr) {
// position_ids provided: cos/sin cache is 2D (max_pos, D/2)
// position_ids bounds validation is handled by shader-side defense-in-depth checks
// (OOB position_ids → pass-through input unchanged). Host-side value scanning is not possible
// because WebGPU program inputs must be GPU buffers (InputMemoryType(OrtMemTypeCPUInput) is
// incompatible with AddInputs).
// Note: ONNX RotaryEmbedding has no base-offset mode (format 0) — position_ids is always
// a 2D tensor (batch_size, sequence_length) when provided.

contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
program
.CacheHint(interleaved_)
Expand Down
114 changes: 114 additions & 0 deletions onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1054,5 +1054,119 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank4MalformedCacheWidth
{}, nullptr, &execution_providers);
}

// Test that OOB position_ids on WebGPU (format 1) pass through input unchanged (shader-side defense).
//
// position_ids carries one explicit id per token, and every id lies beyond the
// max_sequence_length rows available in cos_cache / sin_cache. The expected output is the
// input, bit for bit. Skipped when no WebGPU execution provider is available.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_WebGPU_Passthrough) {
  if (nullptr == DefaultWebGpuExecutionProvider().get()) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch_size = 1;
  const int sequence_length = 2;
  const int num_heads = 2;
  const int head_size = 4;
  const int max_sequence_length = 8;
  const int hidden_size = num_heads * head_size;

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  // Fill the input with 1, 2, 3, ... so every element is distinct and a rotated or
  // misplaced value cannot accidentally compare equal to the expected output.
  std::vector<float> input_data(batch_size * sequence_length * hidden_size);
  for (size_t i = 0; i < input_data.size(); ++i) {
    input_data[i] = static_cast<float>(i + 1);
  }

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_data);
  // Both position_ids exceed max_sequence_length = 8 — shader passes through input unchanged.
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {999, 999});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 1.0f));
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 0.0f));

  // Output should equal input when position_id is OOB (pass-through).
  // Zero tolerance: a pass-through is a copy, so no rounding error is acceptable.
  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  // Restrict the run to the WebGPU execution provider; the behavior under test is
  // implemented in its shader.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Format-0 variant of the WebGPU out-of-bounds pass-through check.
//
// Here position_ids holds a single base offset rather than one id per token. With a base
// of 8 and two tokens the effective positions are [8, 9], both outside a cache that has
// max_sequence_length = 8 rows, so the output is expected to match the input exactly.
// Skipped when no WebGPU execution provider is available.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Format0_OOB_WebGPU_Passthrough) {
  if (DefaultWebGpuExecutionProvider() == nullptr) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch_size = 1;
  const int sequence_length = 2;
  const int num_heads = 2;
  const int head_size = 4;
  const int max_sequence_length = 8;
  const int hidden_size = num_heads * head_size;

  // Input values are 1, 2, 3, ... in element order.
  std::vector<float> input_data(batch_size * sequence_length * hidden_size);
  float next_value = 1.0f;
  for (float& element : input_data) {
    element = next_value;
    next_value += 1.0f;
  }

  // cos = 1 and sin = 0 everywhere in the cache.
  const std::vector<float> cos_values(max_sequence_length * head_size / 2, 1.0f);
  const std::vector<float> sin_values(max_sequence_length * head_size / 2, 0.0f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_data);
  // Format 0: base offset 8, effective positions = [8, 9] — both OOB for max_sequence_length = 8.
  test.AddInput<int64_t>("position_ids", {1}, {8});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2}, sin_values);

  // All positions are OOB, so the expected output is the untouched input with no tolerance.
  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  // Run on the WebGPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.emplace_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Test that negative position_ids pass through on WebGPU (shader-side defense catches raw_pos < 0).
//
// A single token is given position id -5. The expected output is the input, bit for bit.
// Skipped when no WebGPU execution provider is available.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative_WebGPU_Passthrough) {
  if (nullptr == DefaultWebGpuExecutionProvider().get()) {
    GTEST_SKIP() << "WebGPU execution provider is not available.";
  }

  const int batch_size = 1;
  const int sequence_length = 1;
  const int num_heads = 2;
  const int head_size = 4;
  const int max_sequence_length = 8;
  const int hidden_size = num_heads * head_size;

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  // Size the buffer from the full declared shape (not hidden_size alone) so it stays in
  // step with the dims passed to AddInput if batch_size or sequence_length is ever changed.
  // Values are 1, 2, 3, ... so every element is distinct.
  std::vector<float> input_data(batch_size * sequence_length * hidden_size);
  for (size_t i = 0; i < input_data.size(); ++i) {
    input_data[i] = static_cast<float>(i + 1);
  }

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_data);
  // Negative position_id — shader checks raw_pos < 0 and passes through.
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {-5});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 1.0f));
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 0.0f));

  // Output should equal input when position_id is negative (pass-through).
  // Zero tolerance: a pass-through is a copy, so no rounding error is acceptable.
  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  // Restrict the run to the WebGPU execution provider; the behavior under test is
  // implemented in its shader.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultWebGpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime
Loading
Loading