Skip to content

Commit

Permalink
fix wrong variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
peishenyan committed Feb 10, 2025
1 parent 9142044 commit f1abf1b
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
emscripten::val present_value = model_builder.GetBuilder().call<emscripten::val>(
"scatterND", past_value_input, scatter_indices_casted, value_for_scatter, common_options);

// TODO: present_key/value(B,P,kv_N,H) -> reshape(B,P,kv_N,1,H) -> expand(B,P,kv_N,N/kv_N,H) -> reshape(B,P,N,H) ->
// present_key/value(B,P,kv_N,H) -> reshape(B,P,kv_N,1,H) -> expand(B,P,kv_N,N/kv_N,H) -> reshape(B,P,N,H) ->
// true_present_key/value
common_options.set("label", node.Name() + "/GQA/true_present_key/reshape_1");
emscripten::val true_present_key = model_builder.GetBuilder().call<emscripten::val>(
Expand All @@ -290,9 +290,10 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
true_present_value = model_builder.GetBuilder().call<emscripten::val>(
"reshape", true_present_value, emscripten::val::array(group_broadcast_tensor_shape_3), common_options);

// true_present_key(B,P,N,H) -> transpose(B,P,H,N)
options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0, 1, 3, 2})));
options.set("label", node.Name() + "/GQA/present_key/transpose");
true_present_key = model_builder.GetBuilder().call<emscripten::val>("transpose", present_key, options);
true_present_key = model_builder.GetBuilder().call<emscripten::val>("transpose", true_present_key, options);

common_options.set("label", node.Name() + "/GQA/qkv/matmul_1");
emscripten::val matmul_output =
Expand Down

0 comments on commit f1abf1b

Please sign in to comment.