Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions csrc/deepep/ops/op_host/fused_deep_moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class FusedDeepMoe : public OpDef
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();
Expand Down
46 changes: 44 additions & 2 deletions csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,27 @@ namespace ge {
// Input tensor indices of the FusedDeepMoe op.
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;
// Output tensor indices: OUTPUT_X_INDEX is the main expanded-x output,
// OUTPUT_REC_COUNT_INDEX is the ep_recv_count output added in this change.
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;

// Attribute indices; the order must match the Attr() declaration order in the
// OpDef registration (group_ep, ep_rank_size, ep_rank_id, moe_expert_num, ...).
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;

static ge::graphStatus InferShape(gert::InferShapeContext *context)
{
const char *nodeName = context->GetNodeName();
// infer output shape
const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);

if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr) {
gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
recvCountOutShape == nullptr) {
return GRAPH_FAILED;
}
if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
Expand All @@ -37,13 +50,42 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
expandXOutShape->SetDim(0, bs);
expandXOutShape->SetDim(1, h);

// infer recvCount shape
auto attrs = context->GetAttrs();
OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);

OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is nullptr."),
return ge::GRAPH_FAILED);
OP_TILING_CHECK(epRankSizePtr == nullptr, OP_LOGE(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is nullptr."),
return ge::GRAPH_FAILED);
uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);

recvCountOutShape->SetDimNum(1);
bool isShareExpert = (epRankId < sharedExpertRankNum);
if (isShareExpert) {
recvCountOutShape->SetDim(0, epRankSize);
} else {
recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
}

return GRAPH_SUCCESS;
}

// Infer output dtypes: the primary output mirrors the dtype of the expand_x
// input, while the ep_recv_count output is always int32 (it carries counts).
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    const auto xDtype = context->GetInputDataType(EXPAND_X_INDEX);
    context->SetOutputDataType(OUTPUT_X_INDEX, xDtype);
    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
    return ge::GRAPH_SUCCESS;
}

Expand Down
5 changes: 3 additions & 2 deletions csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp,
int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor)
int64_t quantMode, int64_t globalBs, const aclTensor *output, const aclTensor *outputRecvCount,
uint64_t *workspaceSize, aclOpExecutor **executor)
{
return aclnnInnerFusedDeepMoeGetWorkspaceSize(
x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum,
shareExpertRankNum, quantMode, globalBs, output, workspaceSize, executor);
shareExpertRankNum, quantMode, globalBs, output, outputRecvCount, workspaceSize, executor);
}

aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
Expand Down
3 changes: 2 additions & 1 deletion csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspace
const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp,
int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor);
int64_t quantMode, int64_t globalBs, const aclTensor *output, const aclTensor *outputRecvCount,
uint64_t *workspaceSize, aclOpExecutor **executor);

__attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize,
aclOpExecutor *executor, aclrtStream stream);
Expand Down
6 changes: 3 additions & 3 deletions csrc/deepep/ops/op_kernel/fused_deep_moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ extern "C" __global__ __aicore__ void fused_deep_moe(
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output,
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspace, GM_ADDR tiling)
{
icache_preload(8);

// New output recvCount
REGISTER_TILING_DEFAULT(FusedDeepMoeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
FusedDeepMoe<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_smooth_scales, expert_scales, output, workspace, nullptr, &tiling_data);
expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
op.Process();
}
}
27 changes: 15 additions & 12 deletions csrc/deepep/ops/op_kernel/fused_deep_moe.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, G
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR gmResvered, uint32_t epRankSize, uint32_t epRankId,
uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum,
uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs,
uint32_t topK)
GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = DispatchPolicy_;
Expand Down Expand Up @@ -140,6 +140,7 @@ ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, G
gmExpandIdx,
gmEpSendCount,
gmResvered,
gmOutputRecvCount,
epRankSize,
epRankId,
moeExpertNum,
Expand Down Expand Up @@ -244,7 +245,7 @@ class FusedDeepMoe
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output,
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData);
__aicore__ inline void Process();
Expand All @@ -257,6 +258,7 @@ class FusedDeepMoe
GM_ADDR gmWeight2_;
GM_ADDR gmScale2_;
GM_ADDR gmOutput_;
GM_ADDR gmOutputRecvCount_;
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
Expand Down Expand Up @@ -293,7 +295,7 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Init(
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output,
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData)
{
Expand All @@ -309,6 +311,7 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Init(
gmWeight2_ = gmm2_weight;
gmScale2_ = gmm2_weight_scale;
gmOutput_ = output;
gmOutputRecvCount_ = outputRecvCount;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
tilingData_ = tilingData;
Expand Down Expand Up @@ -415,12 +418,12 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Process()
}
}
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>(gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1,
gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutScale1, gmX1Scale,
layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, layoutPerTokenScale2,
gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount,
gmResvered, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_,
sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
#ifdef ENABLE_GMM2_COMBINE
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace

gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();

startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}

Expand Down Expand Up @@ -227,7 +226,6 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
int64_t gmGroupOffsetScale = 0;
int64_t gmGroupOffsetPerTokenScale = 0;
int64_t gmGroupOffsetD = 0;

AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);

Expand All @@ -246,14 +244,12 @@ class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());

EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
params.ptrD + gmGroupOffsetD,
layoutD};

blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
blockEpilogue.UpdateParams(epilogueParams);
uint32_t coreLoops = blockScheduler.GetCoreLoops();
Expand Down
Loading