
Commit 43a042a

eplb
eplb_clean eplb_repair

1 parent dd005e3 commit 43a042a

12 files changed: +228, -54 lines

csrc/deepep/deep_ep.cpp

Lines changed: 16 additions & 6 deletions
@@ -592,10 +592,11 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
     return {combined_x, event, std::function<void()>([] {})};
 }
 
-std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> Buffer::fused_deep_moe(
-    const at::Tensor &x, const at::Tensor &expertIds, const at::Tensor &gmm1PermutedWeight,
-    const at::Tensor &gmm1PermutedWeightScale, const at::Tensor &gmm2Weight, const at::Tensor &gmm2WeightScale,
-    const at::Tensor &expertScalesOptional, int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts, bool use_fp8)
+std::tuple<std::vector<at::Tensor>, std::optional<EventHandle>, std::optional<std::function<void()>>>
+Buffer::fused_deep_moe(const at::Tensor &x, const at::Tensor &expertIds, const at::TensorList &gmm1PermutedWeight,
+                       const at::TensorList &gmm1PermutedWeightScale, const at::TensorList &gmm2Weight,
+                       const at::TensorList &gmm2WeightScale, const at::Tensor &expertScalesOptional,
+                       int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts, bool use_fp8)
 {
     EP_HOST_ASSERT(expertIds.dim() == 2);
     EP_HOST_ASSERT(expertScalesOptional.dim() == 2);
@@ -659,6 +660,15 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
     int bs = this->new_topk_idx.size(0);
     at::Tensor output = at::empty({bs, h}, x.options());
 
+    bool isShareExpert = (rank < shared_expert_num);
+    int64_t localExpertNum = 0;
+    if (isShareExpert) {
+        localExpertNum = num_ranks;
+    } else {
+        localExpertNum = num_ranks * (num_experts / (num_ranks - shared_expert_num));
+    }
+    at::Tensor recvCountOutput = at::empty({localExpertNum}, expertIds.options());
+
     EXEC_NPU_CMD(aclnnFusedDeepMoe,
                  // input
                  x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
@@ -667,7 +677,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
                  hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quantMode,
                  globalBs,
                  // output
-                 output);
+                 output, recvCountOutput);
 
     // ---------- Unpadding ----------
     if (this->is_padding) {
@@ -680,6 +690,6 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
     }
 
     std::optional<EventHandle> event;
-    return {output, event, std::function<void()>([] {})};
+    return {{output, recvCountOutput}, event, std::function<void()>([] {})};
 }
 } // namespace deep_ep
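
Note: with this change Buffer::fused_deep_moe takes the GMM weights and scales as at::TensorList and returns {output, ep_recv_count} instead of a single tensor. A minimal caller-side sketch; the wrapper name, the include path, and the one-element weight lists are illustrative assumptions, not part of the commit:

// Sketch only: assumes a constructed deep_ep::Buffer and weight tensors already
// laid out as the fused op expects (int8 FRACTAL_NZ weights, float scales).
#include <utility>
#include "deep_ep.hpp"

std::pair<at::Tensor, at::Tensor> run_fused_deep_moe(deep_ep::Buffer &buffer, const at::Tensor &x,
                                                     const at::Tensor &expert_ids, const at::Tensor &w1,
                                                     const at::Tensor &w1_scale, const at::Tensor &w2,
                                                     const at::Tensor &w2_scale, const at::Tensor &expert_scales,
                                                     int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts)
{
    // at::TensorList is a c10::ArrayRef<at::Tensor>, so single tensors can be
    // forwarded as one-element lists without copying data.
    auto [tensors, event, hook] = buffer.fused_deep_moe(
        x, expert_ids, {w1}, {w1_scale}, {w2}, {w2_scale}, expert_scales,
        num_max_dispatch_tokens_per_rank, num_experts, /*use_fp8=*/true);
    // tensors[0]: combined output [bs, h]; tensors[1]: ep_recv_count (int32, one entry per local expert)
    return {tensors[0], tensors[1]};
}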

csrc/deepep/deep_ep.hpp

Lines changed: 4 additions & 4 deletions
@@ -85,10 +85,10 @@ struct Buffer {
                          const at::Tensor &packed_recv_count, bool zero_copy, bool async, bool return_recv_hook,
                          const std::optional<at::Tensor> &out);
 
-    std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
-    fused_deep_moe(const at::Tensor &x, const at::Tensor &expertIds, const at::Tensor &gmm1PermutedWeight,
-                   const at::Tensor &gmm1PermutedWeightScale, const at::Tensor &gmm2Weight,
-                   const at::Tensor &gmm2WeightScale, const at::Tensor &expertScalesOptional,
+    std::tuple<std::vector<at::Tensor>, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    fused_deep_moe(const at::Tensor &x, const at::Tensor &expertIds, const at::TensorList &gmm1PermutedWeight,
+                   const at::TensorList &gmm1PermutedWeightScale, const at::TensorList &gmm2Weight,
+                   const at::TensorList &gmm2WeightScale, const at::Tensor &expertScalesOptional,
                    int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts, bool use_fp8);
 };
 } // namespace deep_ep

csrc/deepep/ops/op_host/fused_deep_moe.cpp

Lines changed: 30 additions & 0 deletions
@@ -8,6 +8,8 @@
  */
 #include "register/op_def_registry.h"
 
+#define ENABLE_TENSOR_LIST
+
 namespace ops {
 class FusedDeepMoe : public OpDef
 {
@@ -24,6 +26,28 @@ class FusedDeepMoe : public OpDef
             .DataType({ge::DT_INT32, ge::DT_INT32})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+#ifdef ENABLE_TENSOR_LIST
+        this->Input("gmm1_permuted_weight")
+            .ParamType(DYNAMIC)
+            .DataType({ge::DT_INT8, ge::DT_INT8})
+            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
+            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
+        this->Input("gmm1_permuted_weight_scale")
+            .ParamType(DYNAMIC)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("gmm2_weight")
+            .ParamType(DYNAMIC)
+            .DataType({ge::DT_INT8, ge::DT_INT8})
+            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
+            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
+        this->Input("gmm2_weight_scale")
+            .ParamType(DYNAMIC)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+#else
         this->Input("gmm1_permuted_weight")
             .ParamType(REQUIRED)
             .DataType({ge::DT_INT8, ge::DT_INT8})
@@ -44,6 +68,7 @@ class FusedDeepMoe : public OpDef
             .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+#endif
         this->Input("expert_smooth_scales")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
@@ -59,6 +84,11 @@ class FusedDeepMoe : public OpDef
             .DataType({ge::DT_BF16, ge::DT_FLOAT16})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("ep_recv_count")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
         this->Attr("group_ep").String();
         this->Attr("ep_rank_size").Int();
         this->Attr("ep_rank_id").Int();

csrc/deepep/ops/op_host/fused_deep_moe_infer.cpp

Lines changed: 44 additions & 2 deletions
@@ -16,14 +16,27 @@ namespace ge {
 constexpr uint32_t EXPAND_X_INDEX = 0;
 constexpr uint32_t EXPERT_IDS_INDEX = 1;
 constexpr uint32_t OUTPUT_X_INDEX = 0;
+constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
+
+constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
+constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
+constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
+constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
+constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
+constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
+constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
+constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
 
 static ge::graphStatus InferShape(gert::InferShapeContext *context)
 {
+    const char *nodeName = context->GetNodeName();
+    // infer output shape
     const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
     const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
     gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
-
-    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr) {
+    gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
+    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
+        recvCountOutShape == nullptr) {
         return GRAPH_FAILED;
     }
     if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
@@ -37,13 +50,42 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
     expandXOutShape->SetDim(0, bs);
     expandXOutShape->SetDim(1, h);
 
+    // infer recvCount shape
+    auto attrs = context->GetAttrs();
+    OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
+
+    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
+    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
+    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
+    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
+
+    OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is nullptr."),
+                    return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(epRankSizePtr == nullptr, OP_LOGE(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is nullptr."),
+                    return ge::GRAPH_FAILED);
+    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
+    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
+    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
+    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
+
+    recvCountOutShape->SetDimNum(1);
+    bool isShareExpert = (epRankId < sharedExpertRankNum);
+    if (isShareExpert) {
+        recvCountOutShape->SetDim(0, epRankSize);
+    } else {
+        recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
+    }
+
     return GRAPH_SUCCESS;
 }
 
 static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
 {
     const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX);
     context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType);
+    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
     return ge::GRAPH_SUCCESS;
 }
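
Note: the ep_recv_count dimension inferred above depends on whether the rank hosts a shared expert. A standalone sketch of that arithmetic with assumed example numbers (not values from the commit): with ep_rank_size = 16, shared_expert_rank_num = 2 and moe_expert_num = 56, a shared-expert rank gets 16 entries and a MoE rank gets 16 * (56 / 14) = 64.

// Mirrors the shape-inference branch above; example values are assumptions.
#include <cstdint>
#include <cstdio>

static int64_t RecvCountDim(int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum, int64_t sharedExpertRankNum)
{
    const bool isShareExpert = (epRankId < sharedExpertRankNum);
    // Shared-expert ranks keep one counter per EP rank; MoE ranks keep one per
    // (rank, local expert) pair, with local experts = moeExpertNum / moe-rank count.
    return isShareExpert ? epRankSize : epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum));
}

int main()
{
    std::printf("shared-expert rank: %ld\n", static_cast<long>(RecvCountDim(16, 0, 56, 2)));  // 16
    std::printf("moe rank:           %ld\n", static_cast<long>(RecvCountDim(16, 5, 56, 2)));  // 64
    return 0;
}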

csrc/deepep/ops/op_host/fused_deep_moe_tiling.cpp

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 #include "tiling/hccl/hccl_tiling.h"
 
 #define GM_ALIGN_SIZE 512
-#define ENABLE_TILING_CHECK
+// #define ENABLE_TILING_CHECK
 
 using namespace ge;
 namespace {

csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.cpp

Lines changed: 19 additions & 7 deletions
@@ -12,6 +12,8 @@
 #include "aclnn/opdev/platform.h"
 #include "aclnnInner_fused_deep_moe.h"
 
+#define ENABLE_TENSOR_LIST
+
 enum class NnopbaseHcclServerType {
     NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
     NNOPBASE_HCCL_SERVER_TYPE_MTE,
@@ -23,17 +25,27 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,
 extern "C" {
 #endif
 
-aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize(
-    const aclTensor *x, const aclTensor *expertIds, const aclTensor *gmm1PermutedWeight,
-    const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
-    const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp,
-    int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
-    int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor)
+aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize(const aclTensor *x, const aclTensor *expertIds,
+#ifdef ENABLE_TENSOR_LIST
+                                              const aclTensorList *gmm1PermutedWeight,
+                                              const aclTensorList *gmm1PermutedWeightScale,
+                                              const aclTensorList *gmm2Weight, const aclTensorList *gmm2WeightScale,
+#else
+                                              const aclTensor *gmm1PermutedWeight,
+                                              const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight,
+                                              const aclTensor *gmm2WeightScale,
+#endif
+                                              const aclTensor *expertSmoothScalesOptional,
+                                              const aclTensor *expertScalesOptional, char *groupEp, int64_t epRankSize,
+                                              int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum,
+                                              int64_t shareExpertRankNum, int64_t quantMode, int64_t globalBs,
+                                              const aclTensor *output, const aclTensor *outputRecvCount,
+                                              uint64_t *workspaceSize, aclOpExecutor **executor)
 {
     return aclnnInnerFusedDeepMoeGetWorkspaceSize(
         x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
         expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum,
-        shareExpertRankNum, quantMode, globalBs, output, workspaceSize, executor);
+        shareExpertRankNum, quantMode, globalBs, output, outputRecvCount, workspaceSize, executor);
 }
 
 aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

csrc/deepep/ops/op_host/op_api/aclnn_fused_deep_moe.h

Lines changed: 12 additions & 3 deletions
@@ -10,18 +10,27 @@
 #ifndef FUSED_DEEP_MOE
 #define FUSED_DEEP_MOE
 
+#define ENABLE_TENSOR_LIST
+
 #include "aclnn/acl_meta.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
 __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize(
-    const aclTensor *x, const aclTensor *expertIds, const aclTensor *gmm1PermutedWeight,
-    const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
+    const aclTensor *x, const aclTensor *expertIds,
+#ifdef ENABLE_TENSOR_LIST
+    const aclTensorList *gmm1PermutedWeight, const aclTensorList *gmm1PermutedWeightScale,
+    const aclTensorList *gmm2Weight, const aclTensorList *gmm2WeightScale,
+#else
+    const aclTensor *gmm1PermutedWeight, const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm2Weight,
+    const aclTensor *gmm2WeightScale,
+#endif
     const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional, char *groupEp,
     int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
-    int64_t quantMode, int64_t globalBs, const aclTensor *output, uint64_t *workspaceSize, aclOpExecutor **executor);
+    int64_t quantMode, int64_t globalBs, const aclTensor *output, const aclTensor *outputRecvCount,
+    uint64_t *workspaceSize, aclOpExecutor **executor);
 
 __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoe(void *workspace, uint64_t workspaceSize,
                                                                      aclOpExecutor *executor, aclrtStream stream);
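
Note: callers of the public aclnn entry points now pass the extra outputRecvCount tensor and, with ENABLE_TENSOR_LIST, aclTensorList handles for the GMM weights, while still following the usual two-phase workspace-query/launch pattern. A sketch against the updated prototype; the wrapper name and error handling are assumptions, and all tensor and tensor-list handles are expected to be created by the caller:

// Sketch only: demonstrates the two-phase call with the new signature.
#include <acl/acl.h>
#include "aclnn_fused_deep_moe.h"

aclnnStatus LaunchFusedDeepMoe(const aclTensor *x, const aclTensor *expertIds, const aclTensorList *gmm1W,
                               const aclTensorList *gmm1Scale, const aclTensorList *gmm2W,
                               const aclTensorList *gmm2Scale, const aclTensor *smoothScales,
                               const aclTensor *expertScales, char *groupEp, int64_t epRankSize, int64_t epRankId,
                               int64_t moeExpertNum, int64_t shareExpertNum, int64_t shareExpertRankNum,
                               int64_t quantMode, int64_t globalBs, const aclTensor *output,
                               const aclTensor *outputRecvCount, aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    // Phase 1: query the workspace size; note the trailing outputRecvCount argument.
    aclnnStatus ret = aclnnFusedDeepMoeGetWorkspaceSize(
        x, expertIds, gmm1W, gmm1Scale, gmm2W, gmm2Scale, smoothScales, expertScales, groupEp, epRankSize, epRankId,
        moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, output, outputRecvCount,
        &workspaceSize, &executor);
    if (ret != ACL_SUCCESS) {
        return ret;
    }
    void *workspace = nullptr;
    if (workspaceSize > 0) {
        aclError mret = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        if (mret != ACL_SUCCESS) {
            return static_cast<aclnnStatus>(mret);
        }
    }
    // Phase 2: launch on the given stream, then release the workspace once the stream has drained.
    ret = aclnnFusedDeepMoe(workspace, workspaceSize, executor, stream);
    aclrtSynchronizeStream(stream);
    if (workspace != nullptr) {
        aclrtFree(workspace);
    }
    return ret;
}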

csrc/deepep/ops/op_kernel/fused_deep_moe.cpp

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ extern "C" __global__ __aicore__ void fused_deep_moe(
     GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
     GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
     // output
-    GM_ADDR output,
+    GM_ADDR output, GM_ADDR outputRecvCount,
     // system
     GM_ADDR workspace, GM_ADDR tiling)
 {
@@ -27,7 +27,7 @@ extern "C" __global__ __aicore__ void fused_deep_moe(
     if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
         FusedDeepMoe<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
         op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
-                expert_smooth_scales, expert_scales, output, workspace, nullptr, &tiling_data);
+                expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
         op.Process();
     }
 }
