@@ -39,7 +39,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
 int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
 {
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP8 block scale MOE");
-    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
+    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
+            || routing_logits.scalar_type() == at::ScalarType::BFloat16,
+        "routing_logits must be float or bfloat16.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
     TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
         "routing_logits and hidden_states must have the same number of tokens.");
@@ -69,7 +71,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize
         || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive)
     {
-        TORCH_CHECK(false, "Don't support this routing method type Renormalize(Naive).");
+        TORCH_CHECK(top_k <= 8 && top_k > 0,
+            "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0.");
     }
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4)
     {
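The Renormalize / RenormalizeNaive branch now admits requests instead of rejecting them outright; only the top_k range is constrained. A hedged sketch of the equivalent caller-side precondition, assuming top_k is an int64_t as elsewhere in this function:

// Mirrors the new TORCH_CHECK: per the patch's message, the no-group
// renormalize routing kernel handles between 1 and 8 experts per token.
bool renormalize_top_k_supported(int64_t top_k)
{
    return top_k > 0 && top_k <= 8;
}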
@@ -89,7 +92,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
         = routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16;
     args.mDtypeExpW = routing_bias_dtype == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
 
-    args.routing_logits = routing_logits.data_ptr<float>();
+    args.routing_logits = routing_logits.data_ptr();
     args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
     args.hidden_states = hidden_states.data_ptr();
     args.hidden_states_scale = hidden_states_scale.data_ptr<float>();
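Switching from data_ptr<float>() to the untyped data_ptr() is what makes the bf16 path work: the typed accessor checks that the tensor's dtype matches and would throw on bfloat16 input, while the untyped overload returns a void* and leaves interpretation to the kernel. A small illustration with a hypothetical tensor:

#include <ATen/ATen.h>

void data_ptr_contrast()
{
    at::Tensor logits = at::zeros({4, 8}, at::kBFloat16);
    void* raw = logits.data_ptr();              // OK: type-erased pointer
    // float* typed = logits.data_ptr<float>(); // throws: dtype is BFloat16
    (void) raw;
}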
@@ -153,13 +156,13 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
 
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim);
     auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device());
-    routing_runner.run(routing_logits.data_ptr<float>(), args.routing_bias, args.num_tokens, args.num_experts,
-        args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
-        args.routed_scaling_factor, expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(),
-        total_num_padded_tokens.data_ptr<int>(), expanded_idx_to_permuted_idx.data_ptr<int>(),
-        nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/, permuted_idx_to_token_idx.data_ptr<int>(),
-        expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
-        cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
+    routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
+        args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
+        expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
+        expanded_idx_to_permuted_idx.data_ptr<int>(), nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/,
+        permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(),
+        cta_idx_xy_to_batch_idx.data_ptr<int>(), cta_idx_xy_to_mn_limit.data_ptr<int>(),
+        num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
         static_cast<RoutingMethodType>(routing_method_type), stream);
 
     // MoE kernel except routing
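Correspondingly, run() now receives the type-erased args.routing_logits rather than a float* re-derived at the call site, so one code path serves both dtypes; presumably the runner selects the load type from the routing method and dtype flags it already receives. A hypothetical stand-in for that pattern (not the actual Runner API):

#include <cstdint>
#include <cstring>

enum class Dtype { Fp32, Bfloat16 };

// Sketch: carry a void* plus a dtype tag and reinterpret at the point of use.
// bf16 is the high 16 bits of an fp32, so upconversion is a 16-bit shift.
float load_logit_as_float(void const* logits, Dtype dtype, int64_t i)
{
    if (dtype == Dtype::Fp32)
        return static_cast<float const*>(logits)[i];
    uint16_t raw = static_cast<uint16_t const*>(logits)[i];
    uint32_t bits = static_cast<uint32_t>(raw) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}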