@@ -16,14 +16,27 @@ namespace ge {
// Input positions, used with GetInputShape() / GetInputDataType().
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;

// Output positions, used with GetOutputShape() / SetOutputDataType().
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;

// Attribute positions, used with GetAttrPointer<int64_t>().
// NOTE(review): presumably these mirror the attribute order of the operator
// definition -- confirm against the op registration before reordering.
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
1929
2030static ge::graphStatus InferShape (gert::InferShapeContext *context)
2131{
32+ const char *nodeName = context->GetNodeName ();
33+ // infer output shape
2234 const gert::Shape *expandXShape = context->GetInputShape (EXPAND_X_INDEX);
2335 const gert::Shape *expertIdsShape = context->GetInputShape (EXPERT_IDS_INDEX);
2436 gert::Shape *expandXOutShape = context->GetOutputShape (OUTPUT_X_INDEX);
25-
26- if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ) {
37+ gert::Shape *recvCountOutShape = context->GetOutputShape (OUTPUT_REC_COUNT_INDEX);
38+ if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
39+ recvCountOutShape == nullptr ) {
2740 return GRAPH_FAILED;
2841 }
2942 if (expandXShape->GetDimNum () < 2 || expertIdsShape->GetDimNum () < 1 ) {
@@ -37,13 +50,42 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
3750 expandXOutShape->SetDim (0 , bs);
3851 expandXOutShape->SetDim (1 , h);
3952
53+ // infer recvCount shape
54+ auto attrs = context->GetAttrs ();
55+ OP_TILING_CHECK (attrs == nullptr , OP_LOGE (nodeName, " attrs is nullptr." ), return ge::GRAPH_FAILED);
56+
57+ auto epRankSizePtr = attrs->GetAttrPointer <int64_t >(ATTR_EP_RANK_SIZE_INDEX);
58+ auto epRankIdPtr = attrs->GetAttrPointer <int64_t >(ATTR_EP_RANK_ID_INDEX);
59+ auto moeExpertNumPtr = attrs->GetAttrPointer <int64_t >(ATTR_MOE_EXPERT_NUM_INDEX);
60+ auto sharedExpertRankNumPtr = attrs->GetAttrPointer <int64_t >(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
61+
62+ OP_TILING_CHECK (epRankIdPtr == nullptr , OP_LOGE (nodeName, " epRankIdPtr is nullptr." ), return ge::GRAPH_FAILED);
63+ OP_TILING_CHECK (moeExpertNumPtr == nullptr , OP_LOGE (nodeName, " moeExpertNumPtr is nullptr." ),
64+ return ge::GRAPH_FAILED);
65+ OP_TILING_CHECK (epRankSizePtr == nullptr , OP_LOGE (nodeName, " epRankSizePtr is nullptr." ), return ge::GRAPH_FAILED);
66+ OP_TILING_CHECK (sharedExpertRankNumPtr == nullptr , OP_LOGE (nodeName, " sharedExpertRankNumPtr is nullptr." ),
67+ return ge::GRAPH_FAILED);
68+ uint32_t epRankSize = static_cast <uint32_t >(*epRankSizePtr);
69+ uint32_t moeExpertNum = static_cast <uint32_t >(*moeExpertNumPtr);
70+ uint32_t epRankId = static_cast <uint32_t >(*epRankIdPtr);
71+ uint32_t sharedExpertRankNum = static_cast <uint32_t >(*sharedExpertRankNumPtr);
72+
73+ recvCountOutShape->SetDimNum (1 );
74+ bool isShareExpert = (epRankId < sharedExpertRankNum);
75+ if (isShareExpert) {
76+ recvCountOutShape->SetDim (0 , epRankSize);
77+ } else {
78+ recvCountOutShape->SetDim (0 , epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
79+ }
80+
4081 return GRAPH_SUCCESS;
4182}
4283
4384static ge::graphStatus InferDataType (gert::InferDataTypeContext *context)
4485{
4586 const auto expandXDataType = context->GetInputDataType (EXPAND_X_INDEX);
4687 context->SetOutputDataType (OUTPUT_X_INDEX, expandXDataType);
88+ context->SetOutputDataType (OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
4789 return ge::GRAPH_SUCCESS;
4890}
4991
0 commit comments