From 5e01b5d82bca933f628466c9aab85101b4dfc657 Mon Sep 17 00:00:00 2001
From: EnigmaTHU
Date: Mon, 29 Sep 2025 15:19:54 +0800
Subject: [PATCH] fused_deep_moe: add ep_recv_count output (EPLB recv counts)

---
 csrc/deepep/ops/op_host/fused_deep_moe.cpp    |  5 ++
 .../ops/op_host/fused_deep_moe_infer.cpp      | 46 ++++++++++++++++++-
 .../op_host/op_api/aclnn_fused_deep_moe.cpp   |  5 +-
 .../ops/op_host/op_api/aclnn_fused_deep_moe.h |  3 +-
 csrc/deepep/ops/op_kernel/fused_deep_moe.cpp  |  6 +--
 csrc/deepep/ops/op_kernel/fused_deep_moe.h    | 27 ++++++-----
 ...per_token_dequant_multistage_workspace.hpp |  4 --
 ...equant_swiglu_quant_multistage_workspace.h | 31 +++++++++----
 8 files changed, 93 insertions(+), 34 deletions(-)

diff --git a/csrc/deepep/ops/op_host/fused_deep_moe.cpp b/csrc/deepep/ops/op_host/fused_deep_moe.cpp
index cada3a95..50cd95f9 100644
--- a/csrc/deepep/ops/op_host/fused_deep_moe.cpp
+++ b/csrc/deepep/ops/op_host/fused_deep_moe.cpp
@@ -59,6 +59,11 @@ class FusedDeepMoe : public OpDef
             .DataType({ge::DT_BF16, ge::DT_FLOAT16})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("ep_recv_count")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
         this->Attr("group_ep").String();
         this->Attr("ep_rank_size").Int();
         this->Attr("ep_rank_id").Int();
diff --git a/csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp b/csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp
index 4345c5c9..1391b054 100644
--- a/csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp
+++ b/csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp
@@ -16,14 +16,27 @@ namespace ge {
 constexpr uint32_t EXPAND_X_INDEX = 0;
 constexpr uint32_t EXPERT_IDS_INDEX = 1;
 constexpr uint32_t OUTPUT_X_INDEX = 0;
+constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
+
+constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
+constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
+constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
+constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
+constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
+constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
+constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
+constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
 
 static ge::graphStatus InferShape(gert::InferShapeContext *context)
 {
+    const char *nodeName = context->GetNodeName();
+
+    // infer output shape
     const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
     const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
     gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
-
-    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr) {
+    gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
+    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
+        recvCountOutShape == nullptr) {
         return GRAPH_FAILED;
     }
     if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
@@ -37,6 +50,34 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
     expandXOutShape->SetDim(0, bs);
     expandXOutShape->SetDim(1, h);
 
+    // infer recvCount shape
+    auto attrs = context->GetAttrs();
+    OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
+
+    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
+    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
+    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
+    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
+
+    OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is nullptr."),
+                    return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(epRankSizePtr == nullptr, OP_LOGE(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is nullptr."),
+                    return ge::GRAPH_FAILED);
+    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
+    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
+    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
+    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
+
+    recvCountOutShape->SetDimNum(1);
+    bool isShareExpert = (epRankId < sharedExpertRankNum);
+    if (isShareExpert) {
+        recvCountOutShape->SetDim(0, epRankSize);
+    } else {
+        recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
+    }
+
     return GRAPH_SUCCESS;
 }
 
@@ -44,6 +85,7 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
 {
     const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX);
     context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType);
+    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
     return ge::GRAPH_SUCCESS;
 }
 
diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp b/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp
index 89a3f72e..aa11fd34 100644
--- a/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp
+++ b/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp
@@ -28,12 +28,13 @@ aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize(
     const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
     const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp, int64_t epRankSize,
     int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
-    int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor)
+    int64_t quantMode, int64_t globalBs, const aclTensor *output, const aclTensor *outputRecvCount,
+    uint64_t *workspaceSize, aclOpExecutor **executor)
 {
     return aclnnInnerFusedDeepMoeGetWorkspaceSize(
         x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
         expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum,
-        shareExpertRankNum, quantMode, globalBs, output, workspaceSize, executor);
+        shareExpertRankNum, quantMode, globalBs, output, outputRecvCount, workspaceSize, executor);
 }
 
 aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h b/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h
index d44768f4..9e4abdd4 100644
--- a/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h
+++ b/csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h
@@ -21,7 +21,8 @@ __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspace
     const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
     const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp, int64_t epRankSize,
     int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
-    int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor);
+    int64_t quantMode, int64_t globalBs, const aclTensor *output, const aclTensor *outputRecvCount,
+    uint64_t *workspaceSize, aclOpExecutor **executor);
 
 __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize,
                                                                      aclOpExecutor *executor, aclrtStream stream);
diff --git a/csrc/deepep/ops/op_kernel/fused_deep_moe.cpp b/csrc/deepep/ops/op_kernel/fused_deep_moe.cpp
index 5f48a599..8d25ddb6 100644
--- a/csrc/deepep/ops/op_kernel/fused_deep_moe.cpp
+++ b/csrc/deepep/ops/op_kernel/fused_deep_moe.cpp
@@ -15,19 +15,19 @@ extern "C" __global__ __aicore__ void fused_deep_moe(
     GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
     GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
     // output
-    GM_ADDR output,
+    GM_ADDR output, GM_ADDR outputRecvCount,
     // system
     GM_ADDR workspace, GM_ADDR tiling)
 {
     icache_preload(8);
-
+    // New output recvCount
     REGISTER_TILING_DEFAULT(FusedDeepMoeTilingData);
     KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);  // 1C2V
     GET_TILING_DATA(tiling_data, tiling);
     if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
         FusedDeepMoe op;
         op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
-                expert_smooth_scales, expert_scales, output, workspace, nullptr, &tiling_data);
+                expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
         op.Process();
     }
 }
diff --git a/csrc/deepep/ops/op_kernel/fused_deep_moe.h b/csrc/deepep/ops/op_kernel/fused_deep_moe.h
index e0ac2be6..436d57a4 100644
--- a/csrc/deepep/ops/op_kernel/fused_deep_moe.h
+++ b/csrc/deepep/ops/op_kernel/fused_deep_moe.h
@@ -65,10 +65,10 @@ ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, G
                                   layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
                                   GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
                                   GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
-                                  GM_ADDR gmEpSendCount, GM_ADDR gmResvered, uint32_t epRankSize, uint32_t epRankId,
-                                  uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum,
-                                  uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs,
-                                  uint32_t topK)
+                                  GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
+                                  uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
+                                  uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
+                                  uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK)
 {
     using ArchTag = Arch::AtlasA2;
     using DispatchPolicy = DispatchPolicy_;
@@ -140,6 +140,7 @@ ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, G
         gmExpandIdx,
         gmEpSendCount,
         gmResvered,
+        gmOutputRecvCount,
         epRankSize,
         epRankId,
         moeExpertNum,
@@ -244,7 +245,7 @@ class FusedDeepMoe
         GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
         GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
         // output
-        GM_ADDR output,
+        GM_ADDR output, GM_ADDR outputRecvCount,
         // system
         GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData);
     __aicore__ inline void Process();
@@ -257,6 +258,7 @@ class FusedDeepMoe
     GM_ADDR gmWeight2_;
     GM_ADDR gmScale2_;
     GM_ADDR gmOutput_;
+    GM_ADDR gmOutputRecvCount_;
     GM_ADDR workspaceGM_;
     GM_ADDR gmSmoothScales_;
     GM_ADDR gmexpertScales_;
@@ -293,7 +295,7 @@ __aicore__ inline void FusedDeepMoe::Init(
     GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
     GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
     // output
-    GM_ADDR output,
+    GM_ADDR output, GM_ADDR outputRecvCount,
     // system
     GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData)
 {
@@ -309,6 +311,7 @@ __aicore__ inline void FusedDeepMoe::Init(
     gmWeight2_ = gmm2_weight;
     gmScale2_ = gmm2_weight_scale;
     gmOutput_ = output;
+    gmOutputRecvCount_ = outputRecvCount;
     workspaceGM_ = workspaceGM;
     gmexpertScales_ = expert_scales;
     tilingData_ = tilingData;
@@ -415,12 +418,12 @@ __aicore__ inline void FusedDeepMoe::Process()
         }
     }
     GmmDeqSwigluQuant(gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1,
-                      gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutScale1, gmX1Scale,
-                      layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, layoutPerTokenScale2,
-                      gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount,
-                      gmResvered, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_,
-                      sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
+        Gmm1BlockScheduler>(
+        gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
+        gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
+        layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
+        gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
+        sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
 #ifdef ENABLE_GMM2_COMBINE
     AscendC::PipeBarrier();
     Arch::CrossCoreFlag gmm1AivFinished{0};
diff --git a/csrc/deepep/ops/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/csrc/deepep/ops/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp
index 4a59ac9b..d18dc276 100644
--- a/csrc/deepep/ops/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp
+++ b/csrc/deepep/ops/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp
@@ -193,7 +193,6 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
             gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
             gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
-
             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
         }
@@ -227,7 +226,6 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
         int64_t gmGroupOffsetScale = 0;
         int64_t gmGroupOffsetPerTokenScale = 0;
         int64_t gmGroupOffsetD = 0;
-
         AscendC::GlobalTensor groupList;
         groupList.SetGlobalBuffer(params.ptrGroupList);
@@ -246,14 +244,12 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
             LayoutPerTokenScale layoutPerTokenScale =
                 params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
             LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
-
             EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, layoutScale,
                                           params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
                                           params.ptrD + gmGroupOffsetD, layoutD};
-
             blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
             blockEpilogue.UpdateParams(epilogueParams);
             uint32_t coreLoops = blockScheduler.GetCoreLoops();
diff --git a/csrc/deepep/ops/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/deepep/ops/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
index bcb3454c..a545b3d2 100644
--- a/csrc/deepep/ops/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
+++ b/csrc/deepep/ops/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h
@@ -7,6 +7,7 @@
  * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file
  */
 #pragma once
+
 #include "../../catlass/act/act.hpp"
 #include "../../catlass/act/arch/cross_core_sync.hpp"
 #include "../../catlass/act/arch/resource.hpp"
@@ -406,6 +407,7 @@ class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
         GM_ADDR gmExpandIdx;
         GM_ADDR gmEpSendCount;
         GM_ADDR gmResvered;
+        GM_ADDR gmOutputRecvCount;
 
         uint32_t epRankSize;
         uint32_t epRankId;
@@ -428,9 +430,9 @@ class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
                LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
                GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
                GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_,
-               GM_ADDR gmResvered_, uint32_t epRankSize_, uint32_t epRankId_, uint32_t moeExpertNum_,
-               uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t sharedExpertRankNum_,
-               uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_)
+               GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
+               uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
+               uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_)
             : problemShape(problemShape_),
               problemCount(problemCount_),
              ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)),
@@ -452,6 +454,7 @@ class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
              gmexpertIds(gmexpertIds_),
              gmExpandIdx(gmExpandIdx_),
              gmEpSendCount(gmEpSendCount_),
              gmResvered(gmResvered_),
+              gmOutputRecvCount(gmOutputRecvCount_),
              epRankSize(epRankSize_),
              epRankId(epRankId_),
@@ -519,6 +522,7 @@ class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
         gmA.SetGlobalBuffer(params.ptrA);
         AscendC::GlobalTensor gmB;
         gmB.SetGlobalBuffer(params.ptrB);
+
         AscendC::GlobalTensor groupList;
         groupList.SetGlobalBuffer(params.ptrGroupList);
@@ -1055,7 +1059,7 @@ void RecvCount(int64_t ubOffset)
 }
 
 ACT_DEVICE
-void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset)
+void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount)
 {
     // Prefix sum: determine the offset of the tokens received by this rank within the output
     int64_t subUbOffset = ubOffset;
@@ -1081,7 +1085,15 @@ void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset)
     AscendC::WaitFlag(0);
     AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM,
                         {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
-
+    if (isRecvCore && recvCoreIdx == 0) {
+        AscendC::GlobalTensor<int32_t> recvCountTensor;
+        recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount);
+        AscendC::DataCopyExtParams dataCopyParams = {
+            1U, static_cast<uint32_t>(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U};
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+        AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast<int32_t>(), dataCopyParams);
+    }
     // Space reserved for ReduceSum; ideally the required size would be computed, but since this is just an offset that is released right after use, we skip that calculation.
     AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
     AscendC::PipeBarrier();
@@ -1185,7 +1197,7 @@ void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t
 }
 
 ACT_DEVICE
-void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount)
+void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount)
 {
     ubOffset = 0;
     RecvCount(ubOffset);
@@ -1213,7 +1225,7 @@ void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount)
 
     if (startRankId < recvExpertNum) {
         // Compute the prefix sum and receive the tokens. Implicit constraint: the two calls below must use the same ubOffset argument as RecvCount, otherwise they will not see valid data.
-        GetCumSum(startRankId, recvExpertNum, ubOffset);
+        GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount);
         RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore,
                   ubOffset);
     }
@@ -1277,14 +1289,12 @@ void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ Ele
         LayoutPerTokenScale layoutPerTokenScale =
             wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
         LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
-
         EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, layoutScale,
                                       gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
                                       gmSwigluOutput + gmGroupOffsetD, layoutD};
-
         blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
         blockEpilogue.UpdateParams(epilogueParams);
         uint32_t coreLoops = blockScheduler.GetCoreLoops();
@@ -1495,7 +1505,8 @@ ACT_DEVICE void operator()(Params const &params)
                          (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx);
         }
         if (isRecvCore) {
-            RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount);
+            RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
+                         (GM_ADDR)params.gmOutputRecvCount);
         }
 
         auto gmSwigluOutput = reinterpret_cast<__gm__ float *>(