From 361df45c11f02a6cce5af4fcd326cc43bf69041e Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Wed, 23 Oct 2024 09:18:59 -0700
Subject: [PATCH 1/9] Added do_rotary attribute support to GQA.

---
 .../jsep/webgpu/ops/group-query-attention.ts | 101 +++++++++++++++++-
 .../wasm/jsep/webgpu/ops/rotary-embedding.ts |   2 +-
 2 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index bbe25460d6fd3..5862925c1a5cb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -3,12 +3,15 @@ import { TensorView } from '../../tensor-view';
 import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
-import { ComputeContext } from '../types';
+import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
+import { DataType } from '../../../wasm-common';
 
 import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
 import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
 import { createSplitProgramInfo, SplitAttributes } from './split';
 import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
+import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
+import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType } from './common';
 
 export interface GroupQueryAttentionAttributes {
   numHeads: number;
   kvNumHeads: number;
@@ -32,6 +35,9 @@ export const validateInputs = (
   const value = inputs[2];
   const pastKey = inputs[3];
   const pastValue = inputs[4];
+  if (attributes.doRotary !== 0 && inputs.length <= 7) {
+    throw new Error('cos_cache and sin_cache are expected if do_rotary attribute is non-zero');
+  }
   if (attributes.localWindowSize !== -1) {
     throw new Error('Local attention is not supported');
   }
@@ -235,6 +241,66 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
   return reshapedInput;
 };
 
+const generatePositionIdsProgramInfo = (
+  batchSize: number,
+  sequenceLength: number,
+  seqLens: TensorView,
+  totalSeqLen: TensorView,
+) => {
+  const outputDataType = DataType.int64;
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+  const outputShape = [batchSize * sequenceLength];
+  const outputSize = batchSize * sequenceLength;
+  const programUniforms: ProgramUniform[] = [
+    { type: DataType.uint32, data: outputSize },
+    { type: DataType.uint32, data: sequenceLength },
+  ];
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
+    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
+    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);
+
+    const uniforms: UniformsArrayType = [
+      { name: 'output_size', type: 'u32' },
+      { name: 'sequence_length', type: 'u32' },
+    ];
+    const outputType = tensorTypeToWsglStorageType(outputDataType);
+
+    return `
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
+    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
+    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
+    let batch_idx = global_idx / uniforms.sequence_length;
+    let sequence_idx = global_idx % uniforms.sequence_length;
+    var pos_id: ${outputType} = ${outputType}(0);
+    if (is_first_prompt == false) {
+      let total_seqlen = u32(${seqLensInputHelper.getByOffset('batch_idx')}) + 1u;
+      let past_seqlen = u32(total_seqlen - uniforms.sequence_length);
+      if (past_seqlen + sequence_idx < total_seqlen) {
+        pos_id = ${outputType}(past_seqlen + sequence_idx);
+      } else {
+        pos_id = ${outputType}(1);
+      }
+    }
+    pos_ids[global_idx] = pos_id;
+  }
+  `;
+  };
+  return {
+    name: 'GeneratePositionIds',
+    shaderCache: { hint: `${batchSize}`, inputDependencies },
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType: outputDataType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
 export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
   const params = validateInputs(context.inputs, attributes);
   if (context.inputs[0].dims.length === 5) {
@@ -276,11 +342,38 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     undefined,
     0,
   );
+  const K = maybeTransposeToBNSH(context, key, params);
+  const V = maybeTransposeToBNSH(context, value, params);
+  let qRotary: TensorView | undefined;
+  let kRotary: TensorView | undefined;
+  if (attributes.doRotary) {
+    const posIds = context.compute(
+      generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
+      { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
+    )[0];
+    const cosCache = context.inputs[7];
+    const sinCache = context.inputs[8];
+    const rotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: false,
+      numHeads: params.numHeads,
+      rotaryEmbeddingDim: 2 * cosCache.dims[1],
+      scale: attributes.scale,
+    });
+
+    qRotary = context.compute(
+      createRotaryEmbeddingProgramInfo([Q, posIds, cosCache, sinCache], rotaryEmbeddingAttributes),
+      { inputs: [Q, posIds, cosCache, sinCache], outputs: [-1] },
+    )[0];
+    kRotary = context.compute(
+      createRotaryEmbeddingProgramInfo([K, posIds, cosCache, sinCache], rotaryEmbeddingAttributes),
+      { inputs: [K, posIds, cosCache, sinCache], outputs: [-1] },
+    )[0];
+  }
   applyAttention(
     context,
-    Q,
-    maybeTransposeToBNSH(context, key, params),
-    maybeTransposeToBNSH(context, value, params),
+    attributes.doRotary ? qRotary! : Q,
+    attributes.doRotary ? kRotary! : K,
+    V,
     undefined,
     undefined,
     pastKey,
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
index 8eb7a10ac91fa..fe2567e71d49a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
@@ -75,7 +75,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
   }
 };
 
-const createRotaryEmbeddingProgramInfo = (
+export const createRotaryEmbeddingProgramInfo = (
   inputs: readonly TensorView[],
   attributes: RotaryEmbeddingAttributes,
 ): ProgramInfo => {

From cc317ca46d138e94fa876a629bd0ff0a05dd7712 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Fri, 25 Oct 2024 12:17:00 -0700
Subject: [PATCH 2/9] Apply rotary embedding before transposing to BNSH

---
 .../jsep/webgpu/ops/group-query-attention.ts | 69 ++++++++++---------
 1 file changed, 37 insertions(+), 32 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 5862925c1a5cb..1f808228891d9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -274,18 +274,21 @@ const generatePositionIdsProgramInfo = (
     let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
     let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
     let batch_idx = global_idx / uniforms.sequence_length;
-    let sequence_idx = global_idx % uniforms.sequence_length;
+    let sequence_idx = i32(global_idx % uniforms.sequence_length);
     var pos_id: ${outputType} = ${outputType}(0);
     if (is_first_prompt == false) {
-      let total_seqlen = u32(${seqLensInputHelper.getByOffset('batch_idx')}) + 1u;
-      let past_seqlen = u32(total_seqlen - uniforms.sequence_length);
+      let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;
+      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
       if (past_seqlen + sequence_idx < total_seqlen) {
-        pos_id = ${outputType}(past_seqlen + sequence_idx);
+        // sign extend value and convert to vec2
+        let value = past_seqlen + sequence_idx;
+        let sign_ext = select(0, 0xFFFFFFF, value < 0);
+        pos_id = ${outputType}(u32(sign_ext), u32(value));
       } else {
-        pos_id = ${outputType}(1);
+        pos_id = ${outputType}(0,1);
       }
     }
-    pos_ids[global_idx] = pos_id;
+    ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
   }
   `;
   };
@@ -331,19 +334,6 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     !k && !v
       ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
       : [q, k!, v!];
-
-  const Q = maybeTransposeToBNSHAndAddBias(
-    context,
-    params.batchSize,
-    params.numHeads,
-    params.sequenceLength,
-    params.headSize,
-    query,
-    undefined,
-    0,
-  );
-  const K = maybeTransposeToBNSH(context, key, params);
-  const V = maybeTransposeToBNSH(context, value, params);
   let qRotary: TensorView | undefined;
   let kRotary: TensorView | undefined;
   if (attributes.doRotary) {
@@ -354,25 +344,40 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     const cosCache = context.inputs[7];
     const sinCache = context.inputs[8];
     const rotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
-      interleaved: false,
+      interleaved: attributes.rotaryInterleaved !== 0,
       numHeads: params.numHeads,
-      rotaryEmbeddingDim: 2 * cosCache.dims[1],
+      rotaryEmbeddingDim: 0,
       scale: attributes.scale,
     });
-
-    qRotary = context.compute(
-      createRotaryEmbeddingProgramInfo([Q, posIds, cosCache, sinCache], rotaryEmbeddingAttributes),
-      { inputs: [Q, posIds, cosCache, sinCache], outputs: [-1] },
-    )[0];
-    kRotary = context.compute(
-      createRotaryEmbeddingProgramInfo([K, posIds, cosCache, sinCache], rotaryEmbeddingAttributes),
-      { inputs: [K, posIds, cosCache, sinCache], outputs: [-1] },
-    )[0];
+    const inputs = [query, posIds, cosCache, sinCache];
+    const outputs = [-1];
+    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, rotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
+    inputs.splice(0, 1, key);
+    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, rotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
   }
+  const Q = maybeTransposeToBNSHAndAddBias(
+    context,
+    params.batchSize,
+    params.numHeads,
+    params.sequenceLength,
+    params.headSize,
+    attributes.doRotary ? qRotary! : query,
+    undefined,
+    0,
+  );
+  const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! : key, params);
+  const V = maybeTransposeToBNSH(context, value, params);
+
   applyAttention(
     context,
-    attributes.doRotary ? qRotary! : Q,
-    attributes.doRotary ? kRotary! : K,
+    Q,
+    K,
     V,
     undefined,
     undefined,

From 7952ddb9c2b2af0882bb91e846245920d049a42f Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Mon, 28 Oct 2024 19:26:52 -0700
Subject: [PATCH 3/9] minor changes.

---
 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 1f808228891d9..b9a10d0cb3eca 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -282,8 +282,7 @@ const generatePositionIdsProgramInfo = (
       if (past_seqlen + sequence_idx < total_seqlen) {
         // sign extend value and convert to vec2
         let value = past_seqlen + sequence_idx;
-        let sign_ext = select(0, 0xFFFFFFF, value < 0);
-        pos_id = ${outputType}(u32(sign_ext), u32(value));
+        pos_id = ${outputType}(u32(extractBits(value, 31, 1)), u32(value));
       } else {
         pos_id = ${outputType}(0,1);
       }
@@ -376,8 +375,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
 
   applyAttention(
     context,
-    Q,
-    K,
+    attributes.doRotary ? qRotary! : Q,
+    attributes.doRotary ? kRotary! : K,
     V,
     undefined,
     undefined,

From a95c3d85ee8244ef5e3212388940f41cea5ac511 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Mon, 13 Jan 2025 10:03:11 -0800
Subject: [PATCH 4/9] Fixed the pos_id type.

---
 .../jsep/webgpu/ops/group-query-attention.ts | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 10fd4475cc26e..33e52593397eb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -11,7 +11,7 @@ import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
 import { createSplitProgramInfo, SplitAttributes } from './split';
 import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
 import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
-import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType } from './common';
+import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
 export interface GroupQueryAttentionAttributes {
   numHeads: number;
   kvNumHeads: number;
@@ -27,9 +27,6 @@ export const validateInputs = (
   inputs: readonly TensorView[],
   attributes: GroupQueryAttentionAttributes,
 ): AttentionParameters => {
-  if (attributes.doRotary) {
-    throw new Error('GroupQuerryAttention do_rotary attribute is not supported');
-  }
   if (attributes.doRotary && inputs.length <= 7) {
     throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
   }
@@ -267,7 +264,6 @@ const generatePositionIdsProgramInfo = (
       { name: 'output_size', type: 'u32' },
       { name: 'sequence_length', type: 'u32' },
     ];
-    const outputType = tensorTypeToWsglStorageType(outputDataType);
 
     return `
   ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
@@ -278,16 +274,13 @@ const generatePositionIdsProgramInfo = (
     let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
     let batch_idx = global_idx / uniforms.sequence_length;
     let sequence_idx = i32(global_idx % uniforms.sequence_length);
-    var pos_id: ${outputType} = ${outputType}(0);
+    var pos_id: u32 = 0u;
     if (is_first_prompt == false) {
       let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;
       let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
       if (past_seqlen + sequence_idx < total_seqlen) {
         // sign extend value and convert to vec2
-        let value = past_seqlen + sequence_idx;
-        pos_id = ${outputType}(u32(extractBits(value, 31, 1)), u32(value));
-      } else {
-        pos_id = ${outputType}(0,1);
+        pos_id = u32(past_seqlen + sequence_idx);
       }
     }
     ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
@@ -345,19 +338,25 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     )[0];
     const cosCache = context.inputs[7];
     const sinCache = context.inputs[8];
-    const rotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+    const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
       interleaved: attributes.rotaryInterleaved !== 0,
       numHeads: params.numHeads,
       rotaryEmbeddingDim: 0,
       scale: attributes.scale,
     });
     const inputs = [query, posIds, cosCache, sinCache];
     const outputs = [-1];
-    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, rotaryEmbeddingAttributes), {
+    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), {
       inputs,
       outputs,
     })[0];
     inputs.splice(0, 1, key);
+    const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: attributes.rotaryInterleaved !== 0,
+      numHeads: params.kvNumHeads!,
+      rotaryEmbeddingDim: 0,
+      scale: attributes.scale,
+    });
-    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, rotaryEmbeddingAttributes), {
+    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), {
       inputs,
       outputs,
     })[0];

From 845469b7bf858bdae96f6163f1c32ab5024eac62 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Mon, 13 Jan 2025 15:40:21 -0800
Subject: [PATCH 5/9] minor change

---
 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 33e52593397eb..2001176146b72 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -281,6 +281,8 @@ const generatePositionIdsProgramInfo = (
       if (past_seqlen + sequence_idx < total_seqlen) {
         // sign extend value and convert to vec2
         pos_id = u32(past_seqlen + sequence_idx);
+      } else {
+        pos_id = 1u;
       }
     }
     ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}

From d829e71de37dd35bff25ce6253f5eec94dca949e Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Thu, 16 Jan 2025 11:20:03 -0800
Subject: [PATCH 6/9] Fixed hint for generate positionIDs.

---
 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 2001176146b72..de8e08899e685 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -291,7 +291,7 @@ const generatePositionIdsProgramInfo = (
   };
   return {
     name: 'GeneratePositionIds',
-    shaderCache: { hint: `${batchSize}`, inputDependencies },
+    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
     getRunData: () => ({
       outputs: [{ dims: outputShape, dataType: outputDataType }],
       dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },

From 49b765e292e43979b17122888ea5b70969abec11 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Tue, 28 Jan 2025 16:27:01 -0800
Subject: [PATCH 7/9] minor bug

---
 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index de8e08899e685..64de84ebb7ac6 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -379,8 +379,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
 
   applyAttention(
     context,
-    attributes.doRotary ? qRotary! : Q,
-    attributes.doRotary ? kRotary! : K,
+    Q,
+    K,
     V,
     undefined,
     undefined,

From ee39ae0e241fa1c9d2343ce352c8ddb0a9e672b6 Mon Sep 17 00:00:00 2001
From: SatyaKumarJ
Date: Thu, 20 Feb 2025 14:22:05 -0800
Subject: [PATCH 8/9] Fixed GeneratePositionIds code.
---
 .../jsep/webgpu/ops/group-query-attention.ts | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 64de84ebb7ac6..fe72555c40ea0 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -254,6 +254,7 @@ const generatePositionIdsProgramInfo = (
   const programUniforms: ProgramUniform[] = [
     { type: DataType.uint32, data: outputSize },
     { type: DataType.uint32, data: sequenceLength },
+    { type: DataType.uint32, data: batchSize },
   ];
   const getShaderSource = (shaderHelper: ShaderHelper) => {
     const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
@@ -263,6 +264,7 @@ const generatePositionIdsProgramInfo = (
     const uniforms: UniformsArrayType = [
       { name: 'output_size', type: 'u32' },
       { name: 'sequence_length', type: 'u32' },
+      { name: 'batch_size', type: 'u32' },
     ];
 
     return `
@@ -274,16 +276,25 @@ const generatePositionIdsProgramInfo = (
     let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
     let batch_idx = global_idx / uniforms.sequence_length;
     let sequence_idx = i32(global_idx % uniforms.sequence_length);
-    var pos_id: u32 = 0u;
-    if (is_first_prompt == false) {
-      let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;
+    var pos_id: i32 = 0;
+    let seqlen = ${seqLensInputHelper.getByOffset("batch_idx")};
+    let total_seqlen = seqlen + 1;
+    if (is_first_prompt) {
+      if (sequence_idx < total_seqlen) {
+        pos_id = sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset("global_idx", "pos_id")}
+    } else if (is_subsequent_prompt) {
       let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
       if (past_seqlen + sequence_idx < total_seqlen) {
-        // sign extend value and convert to vec2
-        pos_id = u32(past_seqlen + sequence_idx);
+        pos_id = past_seqlen + sequence_idx;
       } else {
-        pos_id = 1u;
+        pos_id = 1;
       }
-    }
-    ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+      ${positionIdsHelper.setByOffset("global_idx", "pos_id")}
+    } else if (global_idx < uniforms.batch_size) {
+      ${positionIdsHelper.setByOffset("global_idx", "seqlen")}
+    };
   }

From c2a83360ca70483e4dca442fe0b16480f8a254d8 Mon Sep 17 00:00:00 2001
From: SatyaKumarJ
Date: Thu, 20 Feb 2025 15:21:53 -0800
Subject: [PATCH 9/9] lint

---
 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index fe72555c40ea0..32b3c54f734dc 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -277,7 +277,7 @@ const generatePositionIdsProgramInfo = (
     let batch_idx = global_idx / uniforms.sequence_length;
     let sequence_idx = i32(global_idx % uniforms.sequence_length);
     var pos_id: i32 = 0;
-    let seqlen = ${seqLensInputHelper.getByOffset("batch_idx")};
+    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
     let total_seqlen = seqlen + 1;
     if (is_first_prompt) {
       if (sequence_idx < total_seqlen) {
@@ -285,7 +285,7 @@ const generatePositionIdsProgramInfo = (
       } else {
         pos_id = 1;
       }
-      ${positionIdsHelper.setByOffset("global_idx", "pos_id")}
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
     } else if (is_subsequent_prompt) {
       let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
       if (past_seqlen + sequence_idx < total_seqlen) {
@@ -293,9 +293,9 @@ const generatePositionIdsProgramInfo = (
         } else {
           pos_id = 1;
         }
-        ${positionIdsHelper.setByOffset("global_idx", "pos_id")}
+        ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
       } else if (global_idx < uniforms.batch_size) {
-        ${positionIdsHelper.setByOffset("global_idx", "seqlen")}
+        ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
       };
     }
   `;
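
Reviewer note (not part of the series): below is a minimal CPU-side sketch of the position-id logic the final GeneratePositionIds shader computes after PATCH 9. The function name and array types are illustrative, and it assumes seq_lens[b] holds the per-batch total sequence length minus one, as implied by the shader's `let total_seqlen = seqlen + 1;`.

```ts
// Reference model of the GeneratePositionIds WGSL shader (illustrative only).
// Assumption: seqLens[b] = total sequence length for batch b, minus one.
function generatePositionIdsReference(
  batchSize: number,
  sequenceLength: number,
  seqLens: Int32Array, // models the seq_lens input tensor
  totalSeqLen: number, // models total_seq_lens[0]
): number[] {
  const posIds = new Array<number>(batchSize * sequenceLength).fill(0);
  const isSubsequentPrompt = sequenceLength > 1 && sequenceLength !== totalSeqLen;
  const isFirstPrompt = !isSubsequentPrompt && sequenceLength === totalSeqLen;
  for (let b = 0; b < batchSize; ++b) {
    const totalSeqlen = seqLens[b] + 1;
    for (let s = 0; s < sequenceLength; ++s) {
      const idx = b * sequenceLength + s;
      if (isFirstPrompt) {
        // first prompt: positions run 0..n-1; padded slots get 1
        posIds[idx] = s < totalSeqlen ? s : 1;
      } else if (isSubsequentPrompt) {
        // continuation prompt: positions continue after the past sequence
        const pastSeqlen = totalSeqlen - sequenceLength;
        posIds[idx] = pastSeqlen + s < totalSeqlen ? pastSeqlen + s : 1;
      } else if (idx < batchSize) {
        // token generation (sequenceLength == 1): one id per batch, the past length
        posIds[idx] = seqLens[b];
      }
    }
  }
  return posIds;
}
```

Since the position ids are small non-negative integers, storing pos_id as a plain i32/u32 and letting the output helper widen it to int64 (as PATCH 4 onward does) avoids the manual vec2<u32> sign-extension attempted in PATCH 2 and PATCH 3.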