Commit c177aa8

initial commit
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent: 8adaf0b

2 files changed: +22 additions, −19 deletions

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 13 additions & 10 deletions
@@ -39,7 +39,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex)
 {
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP8 block scale MOE");
-    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
+    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
+            || routing_logits.scalar_type() == at::ScalarType::BFloat16,
+        "routing_logits must be float or bfloat16.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
     TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
         "routing_logits and hidden_states must have the same number of tokens.");

@@ -69,7 +71,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize
         || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive)
     {
-        TORCH_CHECK(false, "Don't support this routing method type Renormalize(Naive).");
+        TORCH_CHECK(top_k <= 8 && top_k > 0,
+            "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0.");
     }
     else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4)
     {

@@ -89,7 +92,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option
         = routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16;
     args.mDtypeExpW = routing_bias_dtype == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;

-    args.routing_logits = routing_logits.data_ptr<float>();
+    args.routing_logits = routing_logits.data_ptr();
     args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
     args.hidden_states = hidden_states.data_ptr();
     args.hidden_states_scale = hidden_states_scale.data_ptr<float>();

@@ -153,13 +156,13 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::option

     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim);
     auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device());
-    routing_runner.run(routing_logits.data_ptr<float>(), args.routing_bias, args.num_tokens, args.num_experts,
-        args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
-        args.routed_scaling_factor, expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(),
-        total_num_padded_tokens.data_ptr<int>(), expanded_idx_to_permuted_idx.data_ptr<int>(),
-        nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/, permuted_idx_to_token_idx.data_ptr<int>(),
-        expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
-        cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
+    routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
+        args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
+        expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
+        expanded_idx_to_permuted_idx.data_ptr<int>(), nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/,
+        permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(),
+        cta_idx_xy_to_batch_idx.data_ptr<int>(), cta_idx_xy_to_mn_limit.data_ptr<int>(),
+        num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false, true,
         static_cast<RoutingMethodType>(routing_method_type), stream);

     // MoE kernel except routing
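
Taken together, the C++ changes do three things: the dtype check now admits bfloat16 routing logits in addition to float32, the Renormalize/RenormalizeNaive branch no longer fails unconditionally but instead enforces 0 < top_k <= 8, and routing_runner.run receives the type-erased args.routing_logits pointer instead of a float-typed one, so the routing kernel sees whichever dtype was passed in. A minimal sketch of the inputs this now admits on the Python side; the shapes, device, and top_k value are illustrative assumptions, not code from this commit:

    import torch

    num_tokens, num_experts = 16, 256

    # Previously the op required float32 routing logits; after this
    # commit a bfloat16 tensor passes the relaxed TORCH_CHECK as well.
    routing_logits = torch.randn(num_tokens, num_experts,
                                 dtype=torch.bfloat16, device="cuda")

    # Renormalize(Naive) routing previously hit TORCH_CHECK(false, ...);
    # it is now accepted whenever 0 < top_k <= 8.
    top_k = 8
    assert 0 < top_k <= 8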

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 9 additions & 9 deletions
@@ -376,12 +376,12 @@ def __init__(
         self,
         num_experts: int,
         top_k: int,
-        n_group: int,
-        topk_group: int,
+        n_group: Optional[int],
+        topk_group: Optional[int],
         intermediate_size: int,
         local_expert_offset: int,
         local_num_experts: int,
-        routed_scaling_factor: float,
+        routed_scaling_factor: Optional[float],
         routing_method_type: int,
     ):

@@ -542,12 +542,12 @@ def fp8_block_scale_moe_runner(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     routing_method_type: int,
 ) -> torch.Tensor:

@@ -598,12 +598,12 @@ def _(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     routing_method_type: int,
 ) -> torch.Tensor:
     num_tokens = hidden_states.shape[0]
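
On the Python side, n_group, topk_group, and routed_scaling_factor become Optional in the __init__ shown in the first hunk, in fp8_block_scale_moe_runner itself, and in its fake-tensor registration, since routing methods such as Renormalize use neither expert grouping nor a routed scaling factor. A hypothetical call-site fragment under that assumption; the leading tensor arguments (hidden states, weights, scales) are elided, the sizes are illustrative, and RENORMALIZE stands in for a RoutingMethodType enum value this diff does not show:

    # Placeholder for the integer RoutingMethodType value of Renormalize;
    # the actual value is defined elsewhere in TensorRT-LLM.
    RENORMALIZE = 1

    routing_kwargs = dict(
        num_experts=256,
        top_k=8,
        n_group=None,                # Optional[int]: no expert grouping
        topk_group=None,             # Optional[int]: unused without groups
        intermediate_size=2048,
        local_expert_offset=0,
        local_num_experts=256,
        routed_scaling_factor=None,  # Optional[float]: no extra scaling
        routing_method_type=RENORMALIZE,
    )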
