Commit a9f13e5

[Kernel][FusedMoE] Add support for bias (#1167)

Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent 691ce91 · commit a9f13e5

6 files changed: +504 -98 lines changed
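In the gated MoE layers these tests exercise, bias support means adding a per-expert bias after each of the two projections: b1, of shape (num_experts, 2, intermediate_size), after the fused gate/up matmul against w1, and b2, of shape (num_experts, hidden_size), after the down matmul against w2. Below is a minimal pure-JAX sketch of that computation using the shapes from gen_moe_inputs in the first diff; moe_with_bias_reference is our name, softmax-over-top-k routing is an assumption, and the kernel's actual routing and accumulation order may differ.

    import jax
    import jax.numpy as jnp

    def moe_with_bias_reference(a, w1, w2, b1, b2, gating_output, top_k):
        """Dense per-token reference for a biased, SiLU-gated MoE (sketch)."""
        weights, idx = jax.lax.top_k(gating_output.astype(jnp.float32), top_k)
        weights = jax.nn.softmax(weights, axis=-1)       # (T, top_k)
        out = jnp.zeros(a.shape, jnp.float32)
        for k in range(top_k):
            e = idx[:, k]                                # (T,) expert id per token
            gate = jnp.einsum('th,thi->ti', a, w1[e, 0])
            up = jnp.einsum('th,thi->ti', a, w1[e, 1])
            if b1 is not None:                           # bias on both w1 halves
                gate = gate + b1[e, 0]
                up = up + b1[e, 1]
            h = jax.nn.silu(gate) * up
            y = jnp.einsum('ti,tih->th', h, w2[e])
            if b2 is not None:                           # bias on down projection
                y = y + b2[e]
            out = out + weights[:, k, None] * y.astype(jnp.float32)
        return out.astype(a.dtype)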


tests/kernels/fused_moe_v1_test.py

Lines changed: 63 additions & 6 deletions
@@ -10,6 +10,15 @@
 jax.config.parse_flags_with_absl()
 
 
+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -19,33 +28,49 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
+    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
+    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
+
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
+
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
         dtype=jnp.float32,
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
+
+    if has_bias:
+        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+    else:
+        b1 = b2 = None
+
     gating_output = (
-        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
+        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
+
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
+    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
+
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
     ) * 30)
+
     gating_output = (gating_output + one_hot).astype(dtype)
-    return a, w1, w2, gating_output
+
+    return a, w1, w2, b1, b2, gating_output
 
 
 def sub_channel_quantize(x, quant_dtype, wsz=256):
@@ -104,18 +129,19 @@ def _test_moe(
     act_fn="silu",
     w_dtype=None,
     subc_quant_wsz=None,
-    use_benchmark_baseline=False,
+    has_bias=False,
     atol=2e-1,
     rtol=2e-1,
 ):
-    a, w1, w2, gating_output = gen_moe_inputs(
+    a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
         dtype,
         top_k,
         num_experts,
         hidden_size,
         intermediate_size,
         num_tokens,
         seed=seed,
+        has_bias=has_bias,
     )
     w1_scale = None
     w2_scale = None
@@ -137,6 +163,8 @@
         subc_quant_wsz=subc_quant_wsz,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
+        b1=b1,
+        b2=b2,
         bt=bt,
         bf=bf,
         bd1=bd1,
@@ -152,6 +180,8 @@
         w2,
         gating_output,
         top_k,
+        b1=b1,
+        b2=b2,
         renormalize_topk_logits=renormalize_topk_logits,
         activation=act_fn,
         subc_quant_wsz=subc_quant_wsz,
@@ -312,6 +342,33 @@ def test_sub_channel_quantization(self, w_dtype):
             bd2c=256,
         )
 
+    def test_bias(self):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            has_bias=True,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
 
 if __name__ == "__main__":
     absltest.main(testLoader=jtu.JaxTestLoader())
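The cdiv and align_to helpers added at the top of this test are the usual ceiling-division and round-up-to-multiple utilities, presumably for sizing padded buffers against the kernel's block shapes. For example:

    def cdiv(a, b):
        assert b != 0
        return (a + b - 1) // b

    def align_to(x, a):
        return cdiv(x, a) * a

    assert cdiv(1000, 512) == 2         # 1000 rows fill two 512-row blocks
    assert align_to(1000, 512) == 1024  # padded up to the next multiple of 512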

tests/layers/vllm/test_mxfp4.py

Lines changed: 88 additions & 3 deletions
@@ -10,7 +10,7 @@
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.interop import torch_view
 from torchax.ops.mappings import j2t, t2j
-from vllm.config import set_current_vllm_config
+from vllm.config import ParallelConfig, set_current_vllm_config
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
@@ -114,8 +114,8 @@ def test_quant_override(model, mesh):
 @pytest.mark.parametrize("hidden_size", [128])
 @pytest.mark.parametrize("num_experts", [8])
 @pytest.mark.parametrize("topk", [2])
-def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
-                        num_experts, topk):
+def test_mxfp4_fused_moe(mesh, num_tokens, intermediate_size, hidden_size,
+                         num_experts, topk):
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -192,3 +192,88 @@ def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
                                atol=0.1)
 
     vllm_fused_moe(jax_a, score)
+
+
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+def test_mxfp4_fused_moe_use_kernel(mesh, num_tokens, intermediate_size,
+                                    hidden_size, num_experts, topk):
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
+    w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
+
+    w1_bias = torch.randn(
+        (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+    w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+        load_format='dummy',
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+    vllm_config.parallel_config = ParallelConfig(
+        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            has_bias=True,
+        )
+    vllm_fused_moe.w13_weight.data = w1_weight
+    vllm_fused_moe.w2_weight.data = w2_weight
+    vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
+    vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
+    vllm_fused_moe.w13_bias.data = w1_bias
+    vllm_fused_moe.w2_bias.data = w2_bias
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
+
+        jax_a = a.to('jax')
+        jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+        score = torch_view(t2j(score))
+        score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+
+        vllm_fused_moe.quant_method.use_kernel = True
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe.quant_method.block_size = {
+            "bt": 32,
+            "bf": 512,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 32,
+            "bfc": 256,
+            "bd1c": 256,
+            "bd2c": 256,
+        }
+
+        vllm_fused_moe(jax_a, score)
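For background on the quantized path: MXFP4 stores FP4 (e2m1) element values together with one shared power-of-two scale per block of 32 consecutive elements. A rough fake-quantization sketch of that scheme, illustrative only — mxfp4_quantize_block and FP4_VALUES are our names; the repo's quantize_to_mxfp4 returns packed weight and scale tensors and may round differently:

    import torch

    # The eight non-negative magnitudes representable in FP4 (e2m1).
    FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def mxfp4_quantize_block(x: torch.Tensor):
        """Fake-quantize one block of 32 values; returns (dequantized, scale)."""
        amax = x.abs().max()
        if amax == 0:
            return torch.zeros_like(x), torch.tensor(1.0)
        # Power-of-two scale so the largest magnitude lands at or below 6.0.
        scale = torch.exp2(torch.ceil(torch.log2(amax / 6.0)))
        q = (x / scale).float()
        # Snap each magnitude to the nearest FP4 grid point, keeping the sign.
        idx = (q.abs().unsqueeze(-1) - FP4_VALUES).abs().argmin(dim=-1)
        return FP4_VALUES[idx] * torch.sign(q) * scale, scale

The hand-written block_size dict in the test (bt, bf, bd1, bd2 and their *c counterparts) presumably pins the kernel's tile sizes so the small test shapes map onto fixed blocks rather than tuned defaults.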

tests/layers/vllm/test_unquantized.py

Lines changed: 14 additions & 11 deletions
@@ -1,4 +1,3 @@
-import os
 import tempfile
 
 import jax
@@ -416,7 +415,6 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
 @pytest.mark.parametrize("topk", [2])
 def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
                    num_experts, topk):
-    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -496,7 +494,6 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
 @pytest.mark.parametrize("topk", [2])
 def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
                         num_experts, topk):
-    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -563,7 +560,6 @@ def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
 @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
 def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
                               num_experts, topk, activation):
-    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -613,21 +609,20 @@ def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
     vllm_fused_moe(jax_a, score)
 
 
-@pytest.mark.parametrize("use_ep", [True])
 @pytest.mark.parametrize("mesh",
                          [test_utils.get_spmd_mesh(jax.local_device_count())])
 @pytest.mark.parametrize("num_tokens", [128, 512])
 @pytest.mark.parametrize("intermediate_size", [256, 512])
 @pytest.mark.parametrize("hidden_size", [256])
 @pytest.mark.parametrize("num_experts", [32])
-@pytest.mark.parametrize("topk", [2])
-def test_fused_moe_use_kernel(use_ep, mesh, num_tokens, intermediate_size,
-                              hidden_size, num_experts, topk):
+@pytest.mark.parametrize("topk", [8])
+@pytest.mark.parametrize("has_bias", [False, True])
+def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
+                              num_experts, topk, has_bias):
 
     if jax.local_device_count() < 8:
         pytest.skip("Test requires at least 8 devices")
 
-    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -636,6 +631,10 @@ def test_fused_moe_use_kernel(use_ep, mesh, num_tokens, intermediate_size,
         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
     w2 = torch.randn(
         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    if has_bias:
+        b1 = torch.randn(
+            (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+        b2 = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
 
     # Use deterministic gating_output generation (same logic as fused_moe_v1_test.py)
     # Generate base gating scores with deterministic pattern
@@ -679,7 +678,7 @@ def test_fused_moe_use_kernel(use_ep, mesh, num_tokens, intermediate_size,
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = dtype
     vllm_config.parallel_config = ParallelConfig(
-        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)
+        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
 
     quant_config = get_tpu_quantization_config(vllm_config, mesh)
     with set_current_vllm_config(vllm_config):
@@ -693,11 +692,15 @@ def test_fused_moe_use_kernel(use_ep, mesh, num_tokens, intermediate_size,
             tp_size=mesh.devices.size,
             dp_size=1,
             quant_config=quant_config,
+            has_bias=has_bias,
         )
-    vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+    vllm_fused_moe.moe_parallel_config.use_ep = True
 
     vllm_fused_moe.w13_weight.data = w1
     vllm_fused_moe.w2_weight.data = w2
+    if has_bias:
+        vllm_fused_moe.w13_bias.data = b1
+        vllm_fused_moe.w2_bias.data = b2
 
     p_spec = P('model', )
     jax_a = torch_view(t2j(a, use_dlpack=False))
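With these changes, the new coverage can be run directly; typical invocations, assuming a TPU host with at least 8 local devices for the kernel variants:

    # Pallas kernel test (absltest-based):
    python tests/kernels/fused_moe_v1_test.py

    # vLLM layer tests:
    pytest tests/layers/vllm/test_unquantized.py -k test_fused_moe
    pytest tests/layers/vllm/test_mxfp4.py -k test_mxfp4_fused_moe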
