@@ -39,7 +39,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
 {
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP8 block scale MOE");
-    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
+    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
+            || routing_logits.scalar_type() == at::ScalarType::BFloat16,
+        "routing_logits must be float or bfloat16.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
     TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
         "routing_logits and hidden_states must have the same number of tokens.");
@@ -69,7 +71,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize
         || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive)
     {
-        TORCH_CHECK(false, "Don't support this routing method type Renormalize(Naive).");
+        TORCH_CHECK(top_k <= 8 && top_k > 0,
+            "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0.");
     }
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4)
     {
@@ -89,7 +92,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
         = routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16;
     args.mDtypeExpW = routing_bias_dtype == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
 
-    args.routing_logits = routing_logits.data_ptr<float>();
+    args.routing_logits = routing_logits.data_ptr();
     args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
     args.hidden_states = hidden_states.data_ptr();
     args.hidden_states_scale = hidden_states_scale.data_ptr<float>();
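Note on the hunk above: once bfloat16 is admitted, the typed accessor `routing_logits.data_ptr<float>()` would fail its internal dtype check for a bfloat16 tensor, so the untyped `data_ptr()` is stored instead (assuming `args.routing_logits` is a type-erased pointer; the args struct is not shown in this diff). A sketch of the distinction:

    #include <ATen/ATen.h>

    // Sketch: why the untyped accessor is required here.
    void* routing_ptr(at::Tensor const& t)
    {
        // t.data_ptr<float>() throws if t is BFloat16; the untyped
        // overload returns the raw buffer for any dtype.
        return t.data_ptr();
    }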
@@ -153,13 +156,13 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
 
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim);
     auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device());
-    routing_runner.run(routing_logits.data_ptr<float>(), args.routing_bias, args.num_tokens, args.num_experts,
-        args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
-        args.routed_scaling_factor, expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(),
-        total_num_padded_tokens.data_ptr<int>(), expanded_idx_to_permuted_idx.data_ptr<int>(),
-        nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/, permuted_idx_to_token_idx.data_ptr<int>(),
-        expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
-        cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
+    routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
+        args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
+        expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
+        expanded_idx_to_permuted_idx.data_ptr<int>(), nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/,
+        permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(),
+        cta_idx_xy_to_batch_idx.data_ptr<int>(), cta_idx_xy_to_mn_limit.data_ptr<int>(),
+        num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
         static_cast<RoutingMethodType>(routing_method_type), stream);
 
     // MoE kernel except routing
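Note on the hunk above: the routing runner now receives the type-erased `args.routing_logits` captured earlier instead of a freshly typed `data_ptr<float>()`, keeping the routing path dtype-agnostic end to end; interpreting the buffer is left to the kernel. A hedged sketch of the kind of dispatch a consumer could perform (the helper name is hypothetical):

    #include <ATen/ATen.h>

    // Hypothetical: read one logit from a type-erased buffer.
    float logit_at(void const* raw, at::ScalarType st, int64_t i)
    {
        if (st == at::ScalarType::BFloat16)
            return static_cast<float>(static_cast<at::BFloat16 const*>(raw)[i]);
        return static_cast<float const*>(raw)[i];
    }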