[WIP][WebGPU/JSEP] Support group query attention do_rotary attribute #23524

Draft · wants to merge 11 commits into main
114 changes: 106 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -3,12 +3,15 @@

import { TensorView } from '../../tensor-view';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext } from '../types';
import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import { DataType } from '../../../wasm-common';

import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
export interface GroupQueryAttentionAttributes {
numHeads: number;
kvNumHeads: number;
@@ -24,9 +27,6 @@ export const validateInputs = (
inputs: readonly TensorView[],
attributes: GroupQueryAttentionAttributes,
): AttentionParameters => {
if (attributes.doRotary) {
throw new Error('GroupQuerryAttention do_rotary attribute is not supported');
}
if (attributes.doRotary && inputs.length <= 7) {
throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
}
@@ -35,6 +35,9 @@
const value = inputs[2];
const pastKey = inputs[3];
const pastValue = inputs[4];
if (attributes.doRotary !== 0 && inputs.length <= 7) {
throw new Error('cos_cache and sin_cache are expected if do_rotary attribute is non-zero');
}
if (attributes.localWindowSize !== -1) {
throw new Error('Local attention is not supported');
}
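
For reference, the checks above assume cos_cache and sin_cache arrive as the eighth and ninth inputs, matching the `context.inputs[7]` / `context.inputs[8]` reads further down. A sketch of the assumed ordering (illustrative only; consult the GroupQueryAttention contrib-op spec for the authoritative list):

```ts
// Assumed GroupQueryAttention input ordering, as indexed in this file.
const GQA_INPUTS = [
  'query', // inputs[0]
  'key', // inputs[1]
  'value', // inputs[2]
  'past_key', // inputs[3]
  'past_value', // inputs[4]
  'seqlens_k', // inputs[5]
  'total_sequence_length', // inputs[6]
  'cos_cache', // inputs[7], required when do_rotary != 0
  'sin_cache', // inputs[8], required when do_rotary != 0
] as const;
```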
@@ -238,6 +241,66 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
return reshapedInput;
};

const generatePositionIdsProgramInfo = (
batchSize: number,
sequenceLength: number,
seqLens: TensorView,
totalSeqLen: TensorView,
) => {
const outputDataType = DataType.int64;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
const outputShape = [batchSize * sequenceLength];
const outputSize = batchSize * sequenceLength;
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: sequenceLength },
];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);

const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'sequence_length', type: 'u32' },
];

return `
${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
let batch_idx = global_idx / uniforms.sequence_length;
let sequence_idx = i32(global_idx % uniforms.sequence_length);
var pos_id: u32 = 0u;
if (is_first_prompt == false) {
let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;
let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
if (past_seqlen + sequence_idx < total_seqlen) {
// position id for this token; written out as i64 (vec2<u32>) by the output helper
pos_id = u32(past_seqlen + sequence_idx);
} else {
pos_id = 1u;
}
}
${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
}
`;
};
return {
name: 'GeneratePositionIds',
shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: outputDataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
};
};
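
To make the shader's indexing concrete, here is a CPU mirror of the same computation (a hypothetical helper, not part of this PR; it follows the WGSL above, assuming int32 seqlens_k values and a scalar total_sequence_length):

```ts
// CPU sketch of the GeneratePositionIds shader logic.
function computePositionIds(
  seqLens: Int32Array, // seqlens_k: last token index per batch entry
  totalSeqLen: number, // total_sequence_length (scalar)
  batchSize: number,
  sequenceLength: number,
): BigInt64Array {
  const posIds = new BigInt64Array(batchSize * sequenceLength);
  const isSubsequentPrompt = sequenceLength > 1 && sequenceLength !== totalSeqLen;
  const isFirstPrompt = !isSubsequentPrompt && sequenceLength === totalSeqLen;
  for (let b = 0; b < batchSize; ++b) {
    for (let s = 0; s < sequenceLength; ++s) {
      let posId = 0; // first prompt: every position id stays 0
      if (!isFirstPrompt) {
        const totalSeqlen = seqLens[b] + 1; // seqlens_k holds the last position, hence +1
        const pastSeqlen = totalSeqlen - sequenceLength;
        // tokens past the valid window get 1, matching the shader's else branch
        posId = pastSeqlen + s < totalSeqlen ? pastSeqlen + s : 1;
      }
      posIds[b * sequenceLength + s] = BigInt(posId);
    }
  }
  return posIds;
}
```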

export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
@@ -268,22 +331,57 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
!k && !v
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
: [q, k!, v!];

let qRotary: TensorView | undefined;
let kRotary: TensorView | undefined;
if (attributes.doRotary) {
const posIds = context.compute(
generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
{ inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
)[0];
const cosCache = context.inputs[7];
const sinCache = context.inputs[8];
const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
interleaved: attributes.rotaryInterleaved !== 0,
numHeads: params.numHeads,
rotaryEmbeddingDim: 0,
scale: attributes.scale,
});
const inputs = [query, posIds, cosCache, sinCache];
const outputs = [-1];
qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), {
inputs,
outputs,
})[0];
inputs.splice(0, 1, key);
const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
interleaved: attributes.rotaryInterleaved !== 0,
numHeads: params.kvNumHeads!,
rotaryEmbeddingDim: 0,
scale: attributes.scale,
});
kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), {
inputs,
outputs,
})[0];
}
const Q = maybeTransposeToBNSHAndAddBias(
context,
params.batchSize,
params.numHeads,
params.sequenceLength,
params.headSize,
query,
attributes.doRotary ? qRotary! : query,
undefined,
0,
);
const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! : key, params);
const V = maybeTransposeToBNSH(context, value, params);

applyAttention(
context,
Q,
maybeTransposeToBNSH(context, key, params),
maybeTransposeToBNSH(context, value, params),
K,
V,
undefined,
undefined,
pastKey,
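
For readers unfamiliar with what the reused rotary program computes: per head and per position, standard RoPE rotates pairs of elements by the angle looked up in cos_cache/sin_cache at pos_id. A sketch for the non-interleaved layout (rotaryInterleaved === 0; the interleaved variant pairs adjacent elements instead):

```ts
// Rotate one head's vector x at a given position (non-interleaved RoPE sketch).
// cos and sin are the cos_cache / sin_cache rows for this pos_id, each of
// length x.length / 2.
function rotateNonInterleaved(x: Float32Array, cos: Float32Array, sin: Float32Array): Float32Array {
  const half = x.length / 2;
  const out = new Float32Array(x.length);
  for (let i = 0; i < half; ++i) {
    out[i] = x[i] * cos[i] - x[i + half] * sin[i];
    out[i + half] = x[i] * sin[i] + x[i + half] * cos[i];
  }
  return out;
}
```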
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
@@ -75,7 +75,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
}
};

const createRotaryEmbeddingProgramInfo = (
export const createRotaryEmbeddingProgramInfo = (
inputs: readonly TensorView[],
attributes: RotaryEmbeddingAttributes,
): ProgramInfo => {
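
Exporting createRotaryEmbeddingProgramInfo lets group-query-attention.ts run the rotary kernel directly. The call pattern, mirroring the GQA usage above (a sketch; `input`, `posIds`, `cosCache`, `sinCache`, and `rotaryAttributes` stand for values built as shown earlier):

```ts
// Reuse the rotary-embedding program from another op (sketch).
const rotaryInputs = [input, posIds, cosCache, sinCache];
const rotated = context.compute(createRotaryEmbeddingProgramInfo(rotaryInputs, rotaryAttributes), {
  inputs: rotaryInputs,
  outputs: [-1], // -1 requests a temporary (internal) GPU output
})[0];
```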