
Commit 1361862

Authored by danielafrimi (Daniel Afrimi); co-authored with tlrmchlsmth (Tyler Michael Smith)
[MoE-FP8-modelopt] Add FlashInfer alignment padding for intermediate dimensions (vllm-project#29748)
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: dafrimi <[email protected]>
Co-authored-by: Daniel Afrimi <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent 6ec0d8d commit 1361862

File tree

1 file changed: +48, -0 lines


vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 48 additions & 0 deletions
@@ -81,6 +81,7 @@
     has_flashinfer,
     has_flashinfer_moe,
 )
+from vllm.utils.math_utils import round_up

 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
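For reference, round_up here is vLLM's integer ceiling-to-multiple helper. A minimal stand-in with the assumed semantics (round x up to the nearest multiple of y) behaves like this; it is a sketch, not the library's implementation:

def round_up(x: int, y: int) -> int:
    # Assumed semantics of vllm.utils.math_utils.round_up: the smallest
    # multiple of y that is >= x.
    return ((x + y - 1) // y) * y

assert round_up(1536, 16) == 1536  # already aligned, left unchanged
assert round_up(1530, 16) == 1536  # rounded up to the next multiple of 16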
@@ -607,6 +608,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         Only supports pre-quantized checkpoints with FP8 weights and scales.
         """

+        if self.flashinfer_moe_backend is not None:
+            self._maybe_pad_intermediate_for_flashinfer(layer)
+
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

@@ -684,6 +688,50 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
             register_moe_scaling_factors(layer)

+    def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
+        """Pad intermediate size so FlashInfer kernels' alignment constraints hold.
+
+        Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
+        used for GEMM to be divisible by a small alignment value. When this is
+        not satisfied (e.g. with certain tensor-parallel sizes), we pad the
+        gate/up and down projection weights along the intermediate dim.
+        """
+        if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
+            return
+
+        # Current local intermediate size (per partition) is the K dimension of
+        # the down projection.
+        num_experts, hidden_size, intermediate = layer.w2_weight.shape
+
+        min_alignment = 16
+        padded_intermediate = round_up(intermediate, min_alignment)
+
+        if padded_intermediate == intermediate:
+            return
+
+        logger.info(
+            "Padding intermediate size from %d to %d for up/down projection weights.",
+            intermediate,
+            padded_intermediate,
+        )
+
+        up_mult = 2 if self.moe.is_act_and_mul else 1
+        padded_gate_up_dim = up_mult * padded_intermediate
+
+        # Pad w13 and w2 along their intermediate dimension.
+        w13 = layer.w13_weight.data
+        padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
+        padded_w13[:, : w13.shape[1], :] = w13
+        layer.w13_weight.data = padded_w13
+
+        w2 = layer.w2_weight.data
+        padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
+        padded_w2[:, :, :intermediate] = w2
+        layer.w2_weight.data = padded_w2
+
+        if hasattr(layer, "intermediate_size_per_partition"):
+            layer.intermediate_size_per_partition = padded_intermediate
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
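Outside of vLLM, the effect of the new method can be reproduced with a small standalone sketch. The pad_intermediate helper and the toy shapes below are illustrative only and not part of the patch; it zero-pads the gate/up and down projection weights to a multiple of 16 in the same way:

import torch

ALIGN = 16  # same minimum alignment the patch assumes for FlashInfer kernels

def pad_intermediate(w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool = True):
    """Zero-pad the intermediate dim of toy MoE weights up to a multiple of ALIGN."""
    num_experts, hidden_size, intermediate = w2.shape
    padded = ((intermediate + ALIGN - 1) // ALIGN) * ALIGN  # round_up(intermediate, ALIGN)
    if padded == intermediate:
        return w13, w2
    up_mult = 2 if is_act_and_mul else 1
    # Gate/up projection: pad along the (possibly gated) output dimension.
    new_w13 = w13.new_zeros((num_experts, up_mult * padded, hidden_size))
    new_w13[:, : w13.shape[1], :] = w13
    # Down projection: pad along the input (K) dimension.
    new_w2 = w2.new_zeros((num_experts, hidden_size, padded))
    new_w2[:, :, :intermediate] = w2
    return new_w13, new_w2

# Toy shapes: 4 experts, hidden=64, intermediate=50 (not a multiple of 16).
w13 = torch.randn(4, 2 * 50, 64)
w2 = torch.randn(4, 64, 50)
p13, p2 = pad_intermediate(w13, w2)
print(p13.shape, p2.shape)  # torch.Size([4, 128, 64]) torch.Size([4, 64, 64])

Because the added rows of w13 and the added columns of w2 are zero, the padded intermediate channels contribute nothing to the down projection, so the layer output is unchanged while the alignment constraint is satisfied.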

0 commit comments
