
Commit 500ebcf

[Kernel][Misc] Remove jax.named_scope
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent 098e3d1 commit 500ebcf

File tree

3 files changed: +235 -244 lines changed


tpu_inference/kernels/fused_moe/v1/kernel.py

Lines changed: 158 additions & 163 deletions
@@ -1376,171 +1376,166 @@ def fused_ep_moe(
     hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
     renorm_str = "-renorm_k" if renormalize_topk_logits else ""
     scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
-    fused_moe = jax.named_scope(scope_name)(
-        pl.pallas_call(
-            functools.partial(
-                _fused_ep_moe_kernel,
-                top_k=top_k,
-                renormalize_topk_logits=renormalize_topk_logits,
-                ep_axis_name=ep_axis_name,
-                act_fn=act_fn,
-                subc_quant_wsz=subc_quant_wsz,
-                bt=bt,
-                bf=bf,
-                bd1=bd1,
-                bd2=bd2,
-                btc=btc,
-                bfc=bfc,
-                bd1c=bd1c,
-                bd2c=bd2c,
-            ),
-            out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
-                                           t_dtype),
-            grid_spec=pltpu.PrefetchScalarGridSpec(
-                num_scalar_prefetch=0,
-                in_specs=[
-                    hbm_block_spec,  # tokens_hbm
-                    hbm_block_spec,  # w1_hbm
-                    hbm_block_spec,  # w2_hbm
-                    None
-                    if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
-                    None
-                    if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
-                    None if b1 is None else hbm_block_spec,  # b1_hbm
-                    None if b2 is None else hbm_block_spec,  # b2_hbm
-                    hbm_block_spec,  # gating_output_hbm
-                    hbm_block_spec,  # a2a_g_hbm
-                ],
-                out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                scratch_shapes=([
-                    # t2e_routing_x2_smem
-                    pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
-                    # d2e_count_x2_smem
-                    pltpu.SMEM((2, num_devices, 1, padded_num_experts),
-                               jnp.int32),
-                    # expert_offsets_x2_smem
-                    pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
-                    # expert_starts_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # expert_sizes_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # a2a_s_sends_x2_smem
-                    pltpu.SMEM((2, ), jnp.int32),
-                    # a2a_s_x2_vmem
-                    pltpu.VMEM(
-                        (
-                            2,
-                            bt * num_devices,
-                            t_packing,
-                            hidden_size // t_packing,
-                        ),
-                        t_dtype,
+    fused_moe = pl.pallas_call(
+        functools.partial(
+            _fused_ep_moe_kernel,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            ep_axis_name=ep_axis_name,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        ),
+        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                       t_dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                hbm_block_spec,  # tokens_hbm
+                hbm_block_spec,  # w1_hbm
+                hbm_block_spec,  # w2_hbm
+                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                None if b1 is None else hbm_block_spec,  # b1_hbm
+                None if b2 is None else hbm_block_spec,  # b2_hbm
+                hbm_block_spec,  # gating_output_hbm
+                hbm_block_spec,  # a2a_g_hbm
+            ],
+            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            scratch_shapes=([
+                # t2e_routing_x2_smem
+                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                # d2e_count_x2_smem
+                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                # expert_offsets_x2_smem
+                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                # expert_starts_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # expert_sizes_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # a2a_s_sends_x2_smem
+                pltpu.SMEM((2, ), jnp.int32),
+                # a2a_s_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                    # a2a_s_acc_x2_vmem
-                    pltpu.VMEM(
-                        (
-                            2,
-                            bt * num_devices,
-                            t_packing,
-                            hidden_size // t_packing,
-                        ),
-                        t_dtype,
+                    t_dtype,
+                ),
+                # a2a_s_acc_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                    # a2a_g_acc_vmem
-                    pltpu.VMEM(
-                        (top_k, bt, t_packing, hidden_size // t_packing),
-                        t_dtype),
-                    # b_gating_x2_vmem
-                    pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
-                    # b_output_x2_vmem
-                    pltpu.VMEM((2, bt, hidden_size), t_dtype),
-                    # b_w1_x2_vmem
-                    pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                    # b_w3_x2_vmem
-                    pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                    # b_w2_x2_vmem
-                    pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
-                    # b_w1_scale_x2_vmem
-                    (None if w1_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bd1 // t_packing // subc_quant_wsz,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_w3_scale_x2_vmem
-                    (None if w1_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bd1 // t_packing // subc_quant_wsz,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_w2_scale_x2_vmem
-                    (None if w2_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bf // subc_quant_wsz,
-                            1,
-                            bd2 // t_packing,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b1_x2_vmem
-                    (None if b1 is None else pltpu.VMEM(
-                        (
-                            2,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b3_x2_vmem
-                    (None if b1 is None else pltpu.VMEM(
-                        (
-                            2,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b2_x2_vmem
-                    (None if b2 is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            1,
-                            bd2 // t_packing,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_acc_vmem
-                    pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
-                    # local_sems
-                    pltpu.SemaphoreType.DMA((2, 5)),
-                    # send_sems
-                    pltpu.SemaphoreType.DMA((2, )),
-                    # recv_sems
-                    pltpu.SemaphoreType.DMA((2, )),
-                    # a2a_gather_sem
-                    pltpu.SemaphoreType.DMA,
-                    # a2a_acc_sem
-                    pltpu.SemaphoreType.DMA,
-                ]),
-            ),
-            compiler_params=pltpu.CompilerParams(
-                collective_id=0,
-                vmem_limit_bytes=100 * 1024 * 1024,
-            ),
-            name=scope_name,
-        ))
+                    t_dtype,
+                ),
+                # a2a_g_acc_vmem
+                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                           t_dtype),
+                # b_gating_x2_vmem
+                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                # b_output_x2_vmem
+                pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                # b_w1_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w3_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w2_x2_vmem
+                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                # b_w1_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w3_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w2_scale_x2_vmem
+                (None if w2_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bf // subc_quant_wsz,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b1_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b3_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b2_x2_vmem
+                (None if b2 is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_acc_vmem
+                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                # local_sems
+                pltpu.SemaphoreType.DMA((2, 5)),
+                # send_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # recv_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # a2a_gather_sem
+                pltpu.SemaphoreType.DMA,
+                # a2a_acc_sem
+                pltpu.SemaphoreType.DMA,
+            ]),
+        ),
+        compiler_params=pltpu.CompilerParams(
+            collective_id=0,
+            vmem_limit_bytes=100 * 1024 * 1024,
+        ),
+        name=scope_name,
+    )
 
     @jax.jit
    @jax.shard_map(
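
Note on the pattern being removed: the commit keeps scope_name and the existing name=scope_name argument on pl.pallas_call, and only drops the jax.named_scope wrapper around the callable that pallas_call returns. Below is a minimal sketch of the before/after call pattern on a toy kernel; the kernel, shapes, and names are made up for illustration, and interpret mode is used so the sketch runs without a TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _double_kernel(x_ref, o_ref):
    # Trivial element-wise kernel; only the call pattern around it matters here.
    o_ref[...] = x_ref[...] * 2


scope_name = "toy-double-kernel"  # hypothetical name, analogous to scope_name above

# Old pattern: wrap the callable returned by pl.pallas_call in a named scope.
double_scoped = jax.named_scope(scope_name)(
    pl.pallas_call(
        _double_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        name=scope_name,
        interpret=True,
    ))

# New pattern: rely on pallas_call's own name= argument to label the kernel.
double_named = pl.pallas_call(
    _double_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    name=scope_name,
    interpret=True,
)

x = jnp.ones((8, 128), jnp.float32)
assert jnp.array_equal(double_scoped(x), double_named(x))

Both variants compute the same result; since name= already carries scope_name into the kernel, the extra named_scope wrapper is presumably redundant, which matches the commit title.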
