@@ -1072,10 +1072,38 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWo
1072
1072
size_t const hopper_size = using_hopper ? HopperGroupedGemmInput::workspaceSize (num_experts_per_node) : 0 ;
1073
1073
size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize (num_experts_per_node);
1074
1074
1075
- std::vector<size_t > workspace{source_rows_size, permuted_rows_size, permuted_experts_size, permuted_data_size,
1076
- total_rows_before_expert_size, softmax_out_size, glu_inter_size,
1075
+ // We do some overlapping of the large workspace buffers. Although we could overlap some of the other buffers, they
1076
+ // are small enough (i.e no factor of hidden size) they will only be a couple MiB at most, so we don't bother
1077
+ // in the case of fused activation we overlap permuted_data and fc2_result
1078
+ // in the case of unfused activation we overlap permuted_data and fc1_result
1079
+ // we need to calculate the max possible size, so use the max of all three
1080
+ size_t overlapped_gemm1_gemm2_inputs = std::max (permuted_data_size, fc2_result_size);
1081
+ // When glu_inter_elems is 0 we are always fused, otherwise we may need the un-fused case
1082
+ if (glu_inter_elems > 0 )
1083
+ {
1084
+ overlapped_gemm1_gemm2_inputs = std::max (overlapped_gemm1_gemm2_inputs, fc1_result_size);
1085
+ }
1086
+
1087
+ // if we have glu_inter we overlap it with fc2_result, otherwise we use fc1_result by itself
1088
+ size_t overlapped_gemm1_gemm2_outputs = fc1_result_size;
1089
+ if (glu_inter_elems > 0 )
1090
+ {
1091
+ overlapped_gemm1_gemm2_outputs
1092
+ = std::max (std::max (glu_inter_size, fc2_result_size), overlapped_gemm1_gemm2_outputs);
1093
+ }
1094
+
1095
+ std::vector<size_t > workspace{ //
1096
+ source_rows_size, //
1097
+ permuted_rows_size, //
1098
+ permuted_experts_size, //
1099
+ total_rows_before_expert_size, //
1100
+ softmax_out_size, //
1101
+ sorter_size, //
1077
1102
// These pointers reuse the same memory
1078
- std::max (fc1_result_size, sorter_size), fc2_result_size, hopper_size, gemm_workspace_size};
1103
+ overlapped_gemm1_gemm2_inputs, //
1104
+ overlapped_gemm1_gemm2_outputs, //
1105
+ hopper_size, //
1106
+ gemm_workspace_size};
1079
1107
return workspace;
1080
1108
}
1081
1109
@@ -1088,7 +1116,9 @@ size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(i
1088
1116
TLLM_CHECK_WITH_INFO (num_experts % ep_size == 0 , " Number of experts must be a multiple of ep size" );
1089
1117
auto workspace = getWorkspaceBufferSizes (
1090
1118
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type);
1091
- return tensorrt_llm::common::calculateTotalWorkspaceSize (workspace.data (), workspace.size ());
1119
+ auto ws_size = tensorrt_llm::common::calculateTotalWorkspaceSize (workspace.data (), workspace.size ());
1120
+ TLLM_LOG_DEBUG (" Mixture Of Experts Plugin requires workspace of %2f MiB" , ws_size / 1024 .f / 1024 .f );
1121
+ return ws_size;
1092
1122
}
1093
1123
1094
1124
template <class T , class WeightType , class OutputType , class Enable >
@@ -1109,29 +1139,38 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::configureWsPtrs(char
1109
1139
source_rows_ = (int *) ws_sliced[0 ];
1110
1140
permuted_rows_ = (int *) ws_sliced[1 ];
1111
1141
permuted_experts_ = (int *) ws_sliced[2 ];
1112
- permuted_data_ = (T*) ws_sliced[3 ];
1113
1142
1114
- total_rows_before_expert_ = (int64_t *) ws_sliced[4 ];
1143
+ total_rows_before_expert_ = (int64_t *) ws_sliced[3 ];
1115
1144
1116
1145
softmax_out_ = nullptr ;
1117
1146
bool const is_pow_2 = (num_experts != 0 ) && ((num_experts & (num_experts - 1 )) == 0 );
1118
1147
if (!is_pow_2 || num_experts > 256 )
1119
1148
{
1120
- softmax_out_ = (float *) ws_sliced[5 ];
1149
+ softmax_out_ = (float *) ws_sliced[4 ];
1121
1150
}
1122
1151
1123
- glu_inter_result_ = (T *) ws_sliced[6 ];
1152
+ sorter_ws_ = (char *) ws_sliced[5 ];
1124
1153
1125
- // These pointers are aliased. Since the sort ws can be overwritten after it is finished
1126
- sorter_ws_ = (char *) ws_sliced[7 ];
1127
- fc1_result_ = (T*) ws_sliced[7 ];
1154
+ // Always 6, but overlapped with either fc1_result_ or fc2_result_
1155
+ permuted_data_ = (T*) ws_sliced[6 ];
1128
1156
1129
- fc2_result_ = (T*) ws_sliced[8 ];
1157
+ bool const is_gated_activation = isGatedActivation (activation_type);
1158
+ bool const use_fused_moe = moe_gemm_runner_.isFusedGatedActivation (is_gated_activation, inter_size, hidden_size);
1159
+ bool const using_hopper = moe_gemm_runner_.isHopperSpecialised ();
1160
+ bool const hopper_has_glu = using_hopper && (mayHaveDifferentGEMMOutputType () || is_gated_activation);
1161
+ bool const non_hopper_has_glu = !using_hopper && !use_fused_moe && is_gated_activation;
1162
+ bool const has_glu_inter_result = hopper_has_glu || non_hopper_has_glu;
1163
+ // Always 7, ignored if not needed
1164
+ glu_inter_result_ = has_glu_inter_result ? (T*) ws_sliced[7 ] : nullptr ;
1165
+
1166
+ // fc1 and fc2 alias one of the above pointers, but it depends on if actfn is fused/unfused which is overlapped
1167
+ fc1_result_ = has_glu_inter_result ? (T*) ws_sliced[6 ] : (T*) ws_sliced[7 ];
1168
+ fc2_result_ = has_glu_inter_result ? (T*) ws_sliced[7 ] : (T*) ws_sliced[6 ];
1130
1169
1131
1170
hopper_grouped_gemm_input_ = {};
1132
1171
if (moe_gemm_runner_.isHopperSpecialised ())
1133
1172
{
1134
- hopper_grouped_gemm_input_.configureWorkspace (ws_sliced[9 ], num_experts_per_node, ws_sliced[10 ], ws_sizes[10 ]);
1173
+ hopper_grouped_gemm_input_.configureWorkspace (ws_sliced[8 ], num_experts_per_node, ws_sliced[9 ], ws_sizes[9 ]);
1135
1174
}
1136
1175
}
1137
1176
@@ -1293,6 +1332,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::runMoe(void const* i
1293
1332
}
1294
1333
else
1295
1334
{
1335
+
1296
1336
// Run the GEMM with activation function overridden with `Identity`, we do the activation separately
1297
1337
ActivationType activation_type = (use_fused_moe) ? fc1_activation_type : ActivationType::Identity;
1298
1338
T* gemm_result = (use_fused_moe) ? fc1_result_ : static_cast <T*>(glu_inter_result_);
0 commit comments