Skip to content

Commit

Permalink
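fix bugs (GQA: build present key/value broadcast shapes in (B, N, P, H) layout; skip the reshape/expand broadcast when group_size == 1)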
Browse files (browse the repository at this point in the history)
  • Loading branch information
peishenyan committed Feb 10, 2025
1 parent f1abf1b commit 203a7c7
Showing 1 changed file with 35 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
std::vector<uint32_t> reshape_output_shape = {qkv_batch_size, qkv_sequence_length, qkv_hidden_size};
std::vector<uint32_t> scatter_indices_shape = {qkv_batch_size, kv_num_heads, qkv_sequence_length, 3};
std::vector<uint32_t> reshape_tensor_shape = {qkv_batch_size, qkv_sequence_length, num_heads, head_size};
std::vector<uint32_t> group_broadcast_tensor_shape_1 = {qkv_batch_size, past_sequence_length, kv_num_heads, 1,
std::vector<uint32_t> group_broadcast_tensor_shape_1 = {qkv_batch_size, kv_num_heads, 1, past_sequence_length,
head_size};
std::vector<uint32_t> group_broadcast_tensor_shape_2 = {qkv_batch_size, past_sequence_length, kv_num_heads,
group_size, head_size};
std::vector<uint32_t> group_broadcast_tensor_shape_3 = {qkv_batch_size, past_sequence_length, num_heads, head_size};
std::vector<uint32_t> group_broadcast_tensor_shape_2 = {qkv_batch_size, kv_num_heads, group_size,
past_sequence_length, head_size};
std::vector<uint32_t> group_broadcast_tensor_shape_3 = {qkv_batch_size, num_heads, past_sequence_length, head_size};

emscripten::val common_options = emscripten::val::object();
if (input_defs[0]->Type() == onnx::Utils::DataTypeUtils::ToType("float16")) {
Expand Down Expand Up @@ -192,6 +192,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
emscripten::val right_constant =
model_builder.GetBuilder().call<emscripten::val>("constant", desc_right, right_buffer);

// query_input -> reshape(B,S,N,H) -> transpose(B,N,S,H) -> new_query
common_options.set("label", node.Name() + "/GQA/query/reshape");
emscripten::val reshaped_query = model_builder.GetBuilder().call<emscripten::val>(
"reshape", query_input, emscripten::val::array(reshape_tensor_shape), common_options);
Expand Down Expand Up @@ -268,29 +269,36 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
emscripten::val present_value = model_builder.GetBuilder().call<emscripten::val>(
"scatterND", past_value_input, scatter_indices_casted, value_for_scatter, common_options);

// present_key/value(B,P,kv_N,H) -> reshape(B,P,kv_N,1,H) -> expand(B,P,kv_N,N/kv_N,H) -> reshape(B,P,N,H) ->
// true_present_key/value
common_options.set("label", node.Name() + "/GQA/true_present_key/reshape_1");
emscripten::val true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"reshape", present_key, emscripten::val::array(group_broadcast_tensor_shape_1), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_key/expand");
true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"expand", true_present_key, emscripten::val::array(group_broadcast_tensor_shape_2), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_key/reshape_2");
true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"reshape", true_present_key, emscripten::val::array(group_broadcast_tensor_shape_3), common_options);

common_options.set("label", node.Name() + "/GQA/true_present_value/reshape_1");
emscripten::val true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"reshape", present_value, emscripten::val::array(group_broadcast_tensor_shape_1), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_value/expand");
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"expand", true_present_value, emscripten::val::array(group_broadcast_tensor_shape_2), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_value/reshape_2");
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"reshape", true_present_value, emscripten::val::array(group_broadcast_tensor_shape_3), common_options);

// true_present_key(B,P,N,H) -> transpose(B,P,H,N)
emscripten::val true_present_key;
emscripten::val true_present_value;
if (group_size != 1) {
// present_key/value(B,kv_N,P,H) -> reshape(B,kv_N,1,P,H) -> expand(B,kv_N,N/kv_N,P,H) -> reshape(B,N,P,H) ->
// true_present_key/value
common_options.set("label", node.Name() + "/GQA/true_present_key/reshape_1");
true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"reshape", present_key, emscripten::val::array(group_broadcast_tensor_shape_1), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_key/expand");
true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"expand", true_present_key, emscripten::val::array(group_broadcast_tensor_shape_2), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_key/reshape_2");
true_present_key = model_builder.GetBuilder().call<emscripten::val>(
"reshape", true_present_key, emscripten::val::array(group_broadcast_tensor_shape_3), common_options);

common_options.set("label", node.Name() + "/GQA/true_present_value/reshape_1");
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"reshape", present_value, emscripten::val::array(group_broadcast_tensor_shape_1), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_value/expand");
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"expand", true_present_value, emscripten::val::array(group_broadcast_tensor_shape_2), common_options);
common_options.set("label", node.Name() + "/GQA/true_present_value/reshape_2");
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"reshape", true_present_value, emscripten::val::array(group_broadcast_tensor_shape_3), common_options);
} else { // no need for broadcast
true_present_key = present_key;
true_present_value = present_value;
}

// true_present_key(B,N,P,H) -> transpose(B,N,H,P)
options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0, 1, 3, 2})));
options.set("label", node.Name() + "/GQA/present_key/transpose");
true_present_key = model_builder.GetBuilder().call<emscripten::val>("transpose", true_present_key, options);
Expand Down Expand Up @@ -471,8 +479,6 @@ bool AttentionOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* i
if (input_defs.size() < 7) {
LOGS(logger, VERBOSE) << op_type << " requires at least seven inputs.";
return false;
} else {
LOGS(logger, VERBOSE) << op_type << " has inputs size: " << input_defs.size();
}

return true;
Expand Down

0 comments on commit 203a7c7

Please sign in to comment.