diff --git a/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py b/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py index c38126679638..424b3db9bf8e 100644 --- a/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py +++ b/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py @@ -189,6 +189,117 @@ def trtllm_fp8_block_scale_routed_moe_wrapper( return trtllm_fp8_block_scale_routed_moe(**kwargs) +def _fake_fp4_block_scale_routed_moe( + topk_ids: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_bias: Optional[torch.Tensor], + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + gemm2_bias: Optional[torch.Tensor], + output1_scale_scalar: Optional[torch.Tensor], + output1_scale_gate_scalar: Optional[torch.Tensor], + output2_scale_scalar: Optional[torch.Tensor], + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + activation_type: int = 3, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + return torch.empty( + (hidden_states.shape[0], gemm2_weights.shape[1]), + dtype=torch.bfloat16, + device=hidden_states.device, + ) + + +@register_custom_op(fake_impl=_fake_fp4_block_scale_routed_moe) +def trtllm_fp4_block_scale_routed_moe_wrapper( + topk_ids: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_bias: Optional[torch.Tensor], + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + gemm2_bias: Optional[torch.Tensor], + output1_scale_scalar: Optional[torch.Tensor], + output1_scale_gate_scalar: Optional[torch.Tensor], + output2_scale_scalar: Optional[torch.Tensor], + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + activation_type: int = 3, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + try: + from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe + except ImportError as e: + raise ImportError( + "Can't import trtllm_fp4_block_scale_routed_moe from flashinfer. " + "Please check flashinfer version." 
+ ) from e + + return trtllm_fp4_block_scale_routed_moe( + topk_ids=topk_ids, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm1_bias=gemm1_bias, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + gemm2_bias=gemm2_bias, + output1_scale_scalar=output1_scale_scalar, + output1_scale_gate_scalar=output1_scale_gate_scalar, + output2_scale_scalar=output2_scale_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + do_finalize=do_finalize, + enable_pdl=enable_pdl, + activation_type=activation_type, + tune_max_num_tokens=tune_max_num_tokens, + )[0] + + def _fake_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index e568579bb7e8..0a291bea5190 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -14,6 +14,7 @@ ) from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe.flashinfer_trtllm_moe import ( + trtllm_fp4_block_scale_routed_moe_wrapper, trtllm_fp8_block_scale_moe_wrapper, trtllm_fp8_block_scale_routed_moe_wrapper, trtllm_fp8_per_tensor_scale_moe_wrapper, @@ -275,15 +276,15 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: w13_weight.size(0), # num_experts ) - # Set flashinfer parameters + # Set flashinfer parameters in-place + copy_or_rebind_param(layer, "w13_weight", gemm1_weights_fp4_shuffled.contiguous()) + copy_or_rebind_param(layer, "w2_weight", gemm2_weights_fp4_shuffled.contiguous()) copy_or_rebind_param( - layer, "gemm1_weights_fp4_shuffled", gemm1_weights_fp4_shuffled + layer, "w13_weight_scale", gemm1_scales_fp4_shuffled.contiguous() ) copy_or_rebind_param( - layer, "gemm2_weights_fp4_shuffled", gemm2_weights_fp4_shuffled + layer, "w2_weight_scale", gemm2_scales_fp4_shuffled.contiguous() ) - copy_or_rebind_param(layer, "gemm1_scales_fp4_shuffled", gemm1_scales_fp4_shuffled) - copy_or_rebind_param(layer, "gemm2_scales_fp4_shuffled", gemm2_scales_fp4_shuffled) # Compute additional scaling factor needed for TRT-LLM w2_input_scale_quant = cast(torch.Tensor, layer.w2_input_scale_quant) @@ -294,14 +295,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: (w2_input_scale_quant * g1_alphas).to(torch.float32), ) - # Clean up weights that won't be used by TRT-LLM - del ( - layer.w2_weight, - layer.w2_weight_scale, - layer.w13_weight, - layer.w13_weight_scale, - ) - @dataclass class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): @@ -560,11 +553,10 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( class FlashInferTrtllmFp4MoeQuantInfo(MoeQuantInfo): """Quantization payload consumed by FlashInfer TRT-LLM FP4 MoE kernels.""" - # Shuffled FP4 weights (processed by align_fp4_moe_weights_for_flashinfer_trtllm) - gemm1_weights_fp4_shuffled: torch.Tensor - gemm2_weights_fp4_shuffled: torch.Tensor - gemm1_scales_fp4_shuffled: torch.Tensor - 
gemm2_scales_fp4_shuffled: torch.Tensor + w13_weight: torch.Tensor + w2_weight: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor # Scaling factors g1_scale_c: torch.Tensor @@ -616,6 +608,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( dispatch_output: StandardDispatchOutput, quant_info: FlashInferTrtllmFp4MoeQuantInfo, runner_config: MoeRunnerConfig, + use_routed_topk: bool = False, ) -> StandardCombineInput: """FlashInfer TRTLLM FP4 MoE forward pass. @@ -633,10 +626,19 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hidden_states = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - assert TopKOutputChecker.format_is_bypassed(topk_output) + if TopKOutputChecker.format_is_bypassed(topk_output): + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(hidden_states.dtype) + ) + else: + router_logits = None + topk_config = None + correction_bias = None - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config routing_method_type = quant_info.routing_method_type # Quantize hidden states to FP4 @@ -644,16 +646,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hidden_states, quant_info.w13_input_scale_quant ) - # DeepSeekV3 style routing requires float32 router logits - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) - ) - with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()): num_tokens = hs_fp4.shape[0] hidden_size = ( @@ -663,46 +655,93 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device ) - result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( - *hs_scale_linear.shape[:-1], -1 - ), - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=quant_info.g1_scale_c, - output1_scale_gate_scalar=quant_info.g1_alphas, - output2_scale_scalar=quant_info.g2_alphas, - num_experts=quant_info.global_num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=quant_info.intermediate_size_per_partition, - local_expert_offset=quant_info.local_expert_offset, - local_num_experts=quant_info.local_num_experts, - routed_scaling_factor=runner_config.routed_scaling_factor, - routing_method_type=( - routing_method_type - if routing_method_type is not None - else RoutingMethodType.Default - ), - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] + if use_routed_topk: + assert ( + runner_config.top_k is not None + ), "runner_config.top_k is required for flashinfer_trtllm_routed." 
+ assert TopKOutputChecker.format_is_standard(topk_output) + packed_topk_ids = _pack_topk_for_flashinfer_routed( + topk_ids=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + ) + + result = trtllm_fp4_block_scale_routed_moe_wrapper( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( + *hs_scale_linear.shape[:-1], -1 + ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=runner_config.top_k, + n_group=None, + topk_group=None, + intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=runner_config.routed_scaling_factor, + routing_method_type=( + RoutingMethodType.TopK + if routing_method_type == RoutingMethodType.DeepSeekV3 + else routing_method_type + ), + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + ) + else: + assert TopKOutputChecker.format_is_bypassed(topk_output) + + result = trtllm_fp4_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( + *hs_scale_linear.shape[:-1], -1 + ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=runner_config.routed_scaling_factor, + routing_method_type=( + routing_method_type + if routing_method_type is not None + else RoutingMethodType.Default + ), + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] return StandardCombineInput(hidden_states=result) @@ -858,6 +897,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed( quant_info: MoeQuantInfo, runner_config: MoeRunnerConfig, ) -> StandardCombineInput: + if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo): + return fused_experts_none_to_flashinfer_trtllm_fp4( + dispatch_output, + quant_info, + runner_config, + use_routed_topk=True, + ) if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo): return fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output, diff --git 
a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..6b285809ba16 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -20,6 +20,7 @@ from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, + replace_parameter, swizzle_blockscale, ) from sglang.srt.utils import next_power_of_2, set_weight_attrs @@ -257,30 +258,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) logger.debug("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) + replace_parameter(layer, "w13_weight", gemm1_weights_fp4_shuffled) + replace_parameter(layer, "w2_weight", gemm2_weights_fp4_shuffled) + replace_parameter(layer, "w13_weight_scale", gemm1_scales_fp4_shuffled) + replace_parameter(layer, "w2_weight_scale", gemm2_scales_fp4_shuffled) # Additional parameter needed for TRT-LLM layer.g1_scale_c = torch.nn.Parameter( (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale else: # swizzle weight scales layer.w13_weight_scale = torch.nn.Parameter( @@ -370,18 +357,14 @@ def apply_weights( routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=layer.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c, output1_scale_gate_scalar=layer.g1_alphas, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c0d9958e45ee..31fb6c511f3a 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1534,6 +1534,7 @@ def __init__(self, quant_config: ModelOptFp4Config): ) self.enable_flashinfer_trtllm_moe = ( get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() ) self._cache_permute_indices = {} @@ -1788,7 +1789,10 @@ def _slice_scale(w): ("w2", layer.w2_weight_scale), ]: # For NVFP4 TRTLLM we require one scale per 16 inputs (last dim == expected_blocks[name]). 
- if get_moe_runner_backend().is_flashinfer_trtllm(): + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): expected_blocks = { "w13": layer.w13_weight.shape[2] * 2 // block_size, "w2": layer.w2_weight.shape[2] * 2 // block_size, @@ -1900,9 +1904,17 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - if get_moe_runner_backend().is_flashinfer_trtllm(): + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): self.runner = MoeRunner( - MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config + ( + MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED + if get_moe_runner_backend().is_flashinfer_trtllm_routed() + else MoeRunnerBackend.FLASHINFER_TRTLLM + ), + moe_runner_config, ) def apply( @@ -1922,9 +1934,8 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config - # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when - # backend is flashinfer_trtllm - if hasattr(layer, "gemm1_weights_fp4_shuffled"): + # FlashInfer TRTLLM FP4 path + if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"): from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( FlashInferTrtllmFp4MoeQuantInfo, ) @@ -1936,10 +1947,10 @@ def apply( ) quant_info = FlashInferTrtllmFp4MoeQuantInfo( - gemm1_weights_fp4_shuffled=layer.gemm1_weights_fp4_shuffled.data, - gemm2_weights_fp4_shuffled=layer.gemm2_weights_fp4_shuffled.data, - gemm1_scales_fp4_shuffled=layer.gemm1_scales_fp4_shuffled.data, - gemm2_scales_fp4_shuffled=layer.gemm2_scales_fp4_shuffled.data, + w13_weight=layer.w13_weight.data, + w2_weight=layer.w2_weight.data, + w13_weight_scale=layer.w13_weight_scale.data, + w2_weight_scale=layer.w2_weight_scale.data, g1_scale_c=layer.g1_scale_c.data, g1_alphas=layer.g1_alphas.data, g2_alphas=layer.g2_alphas.data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 646edb414452..636bce1d1bc4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2698,8 +2698,9 @@ def _handle_moe_kernel_config(self): assert self.quantization in [ "fp8", "mxfp8", + "modelopt_fp4", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', or bfloat16 (None)." + ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)." self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set." 
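# A minimal launch sketch for the combination the server_args check above now
# permits: an NVFP4 ("modelopt_fp4") checkpoint served with the routed
# FlashInfer TRT-LLM MoE backend. The model name and flags mirror the tests
# added below; the entry point is assumed to be the standard SGLang server
# launcher.
#
#   python -m sglang.launch_server \
#       --model-path nvidia/Qwen3-30B-A3B-NVFP4 \
#       --moe-runner-backend flashinfer_trtllm_routed \
#       --tp-size 4 \
#       --ep-size 4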
diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py index b63447a60cd5..9c7260b19966 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py @@ -157,6 +157,49 @@ def test_gsm8k(self): self.assertGreater(metrics["score"], 0.93) +class FlashinferTrtllmGenMoeBackendNVFP4Base: + backend = None + + @classmethod + def setUpClass(cls): + cls.model = "nvidia/Qwen3-30B-A3B-NVFP4" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--moe-runner-backend", + cls.backend, + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["score"], 0.89) + + class TestFlashinferTrtllmGenMoeBackendFP8( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -175,6 +218,12 @@ class TestFlashinferTrtllmGenMoeBackendBF16( backend = "flashinfer_trtllm" +class TestFlashinferTrtllmGenMoeBackendNVFP4( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + class TestFlashinferTrtllmGenMoeBackendFP8Routed( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -193,5 +242,11 @@ class TestFlashinferTrtllmGenMoeBackendBF16Routed( backend = "flashinfer_trtllm_routed" +class TestFlashinferTrtllmGenMoeBackendNVFP4Routed( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm_routed" + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/rl/test_update_weights_from_disk_mxfp8.py b/test/registered/rl/test_update_weights_from_disk_blackwell.py similarity index 69% rename from test/registered/rl/test_update_weights_from_disk_mxfp8.py rename to test/registered/rl/test_update_weights_from_disk_blackwell.py index 504eb1f1d703..eaabc3217be3 100644 --- a/test/registered/rl/test_update_weights_from_disk_mxfp8.py +++ b/test/registered/rl/test_update_weights_from_disk_blackwell.py @@ -15,37 +15,43 @@ ) -class TestServerUpdateWeightsFromDiskMXFP8(CustomTestCase): - model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" +class UpdateWeightsFromDiskModelBase: + model = None base_url = DEFAULT_URL_FOR_TEST request_timeout = 120 update_timeout = 240 + launch_env = None decode_payload = { "text": "The capital of France is", "sampling_params": {"temperature": 0, "max_new_tokens": 16}, } - backend_test_suites = ( - { - "fp8_gemm_backend": "flashinfer_trtllm", - "moe_runner_backend": "flashinfer_trtllm_routed", - }, + backend_test_suites = () + update_test_suites = ( + {"flush_cache": True, "abort_all_requests": False}, + {"flush_cache": False, "abort_all_requests": False}, ) - def _launch_server(self, fp8_gemm_backend, moe_runner_backend): + @classmethod + def setUpClass(cls): + super().setUpClass() + if cls.model is None: + raise NotImplementedError("Subclass must set 'model' attribute") + if not cls.backend_test_suites: + raise NotImplementedError( + "Subclass must set non-empty 
'backend_test_suites'" + ) + + def _launch_server(self, backend_test_suite): + launch_kwargs = {} + if self.launch_env is not None: + launch_kwargs["env"] = self.launch_env + other_args = backend_test_suite.get("other_args") return popen_launch_server( self.model, self.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--base-gpu-id", - "0", - "--tp-size", - "4", - "--fp8-gemm-backend", - fp8_gemm_backend, - "--moe-runner-backend", - moe_runner_backend, - ], + other_args=other_args, + **launch_kwargs, ) def _get_json(self, endpoint, timeout=None): @@ -119,30 +125,19 @@ def _run_update_weights( timeout=self.update_timeout, ) - def test_parameterized_update_weights_mxfp8(self): - update_test_suites = ( - {"flush_cache": True, "abort_all_requests": False}, - {"flush_cache": False, "abort_all_requests": False}, - ) + def test_parameterized_update_weights_from_disk(self): for backend_test_suite in self.backend_test_suites: - with self.subTest(**backend_test_suite): - process = self._launch_server( - backend_test_suite["fp8_gemm_backend"], - backend_test_suite["moe_runner_backend"], - ) + case_name = backend_test_suite.get("name", "default") + with self.subTest(model=self.model, case_name=case_name): + process = self._launch_server(backend_test_suite) try: origin_model_path = self._get_model_info() self.assertEqual(origin_model_path, self.model) self._assert_non_empty_decode() baseline_sig = self._get_decode_logprob_signature() - for update_test_suite in update_test_suites: - with self.subTest( - fp8_gemm_backend=backend_test_suite["fp8_gemm_backend"], - moe_runner_backend=backend_test_suite["moe_runner_backend"], - flush_cache=update_test_suite["flush_cache"], - abort_all_requests=update_test_suite["abort_all_requests"], - ): + for update_test_suite in self.update_test_suites: + with self.subTest(case_name=case_name, **update_test_suite): ret = self._run_update_weights( self.model, flush_cache=update_test_suite["flush_cache"], @@ -161,5 +156,47 @@ def test_parameterized_update_weights_mxfp8(self): kill_process_tree(process.pid) +class TestServerUpdateWeightsFromDiskMXFP8( + UpdateWeightsFromDiskModelBase, CustomTestCase +): + model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_routed_mxfp8", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp8-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm_routed", + ), + }, + ) + + +class TestServerUpdateWeightsFromDiskNVFP4( + UpdateWeightsFromDiskModelBase, CustomTestCase +): + model = "nvidia/Qwen3-30B-A3B-NVFP4" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_nvfp4", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp4-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm_routed", + ), + }, + ) + + if __name__ == "__main__": unittest.main()
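# The FP4 wrapper registered earlier in this patch pairs the real FlashInfer
# kernel with a shape-only fake implementation so the op remains traceable
# under torch.compile / CUDA graph capture. Below is a minimal stand-alone
# illustration of the same pattern using plain torch.library rather than
# SGLang's register_custom_op helper; the op name and function are
# illustrative only.

import torch


@torch.library.custom_op("demo::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: performs the actual computation on device.
    return (x * scale).contiguous()


@scaled_copy.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only the output shape/dtype matter during
    # tracing, analogous to _fake_fp4_block_scale_routed_moe returning
    # torch.empty(...) in the patch above.
    return torch.empty_like(x)


def _demo() -> None:
    # The compiled graph uses the fake impl for shape inference and the real
    # impl at runtime.
    compiled = torch.compile(lambda t: scaled_copy(t, 2.0))
    out = compiled(torch.randn(4, 8))
    assert out.shape == (4, 8)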