     has_flashinfer,
     has_flashinfer_moe,
 )
+from vllm.utils.math_utils import round_up

 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
@@ -607,6 +608,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         Only supports pre-quantized checkpoints with FP8 weights and scales.
         """

+        if self.flashinfer_moe_backend is not None:
+            self._maybe_pad_intermediate_for_flashinfer(layer)
+
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

@@ -684,6 +688,50 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
         register_moe_scaling_factors(layer)

+    def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
+        """Pad intermediate size so FlashInfer kernels' alignment constraints hold.
+
+        Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
+        used for GEMM to be divisible by a small alignment value. When this is
+        not satisfied (e.g. with certain tensor-parallel sizes), we pad the
+        gate/up and down projection weights along the intermediate dim.
+        """
+        if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
+            return
+
+        # Current local intermediate size (per partition) is the K dimension of
+        # the down projection.
+        num_experts, hidden_size, intermediate = layer.w2_weight.shape
+
+        min_alignment = 16
+        padded_intermediate = round_up(intermediate, min_alignment)
+
+        if padded_intermediate == intermediate:
+            return
+
+        logger.info(
+            "Padding intermediate size from %d to %d for up/down projection weights.",
+            intermediate,
+            padded_intermediate,
+        )
+
+        up_mult = 2 if self.moe.is_act_and_mul else 1
+        padded_gate_up_dim = up_mult * padded_intermediate
+
+        # Zero-pad w13 and w2 along their intermediate dimensions.
+        w13 = layer.w13_weight.data
+        padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
+        padded_w13[:, : w13.shape[1], :] = w13
+        layer.w13_weight.data = padded_w13
+
+        w2 = layer.w2_weight.data
+        padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
+        padded_w2[:, :, :intermediate] = w2
+        layer.w2_weight.data = padded_w2
+
+        if hasattr(layer, "intermediate_size_per_partition"):
+            layer.intermediate_size_per_partition = padded_intermediate
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
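
For intuition, here is a small standalone sketch (not part of the diff) of the padding arithmetic the new helper applies. The expert/hidden/intermediate sizes are made up, and `round_up` is re-implemented locally as the usual round-to-multiple helper, standing in for `vllm.utils.math_utils.round_up`:

```python
# Illustrative sketch only: shapes are invented, round_up is a local stand-in.
import torch


def round_up(x: int, align: int) -> int:
    # Smallest multiple of `align` that is >= x.
    return ((x + align - 1) // align) * align


num_experts, hidden_size, intermediate = 8, 1024, 1436  # 1436 % 16 != 0
min_alignment = 16
padded = round_up(intermediate, min_alignment)  # 1440

# Down projection [E, H, I]: zero-pad the trailing intermediate (K) dimension.
w2 = torch.randn(num_experts, hidden_size, intermediate, dtype=torch.float16)
padded_w2 = w2.new_zeros(num_experts, hidden_size, padded)
padded_w2[:, :, :intermediate] = w2

# Gated up projection [E, 2*I, H]: zero-pad along the intermediate rows.
w13 = torch.randn(num_experts, 2 * intermediate, hidden_size, dtype=torch.float16)
padded_w13 = w13.new_zeros(num_experts, 2 * padded, hidden_size)
padded_w13[:, : w13.shape[1], :] = w13

print(padded_w2.shape)   # torch.Size([8, 1024, 1440])
print(padded_w13.shape)  # torch.Size([8, 2880, 1024])
```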