@@ -1376,171 +1376,166 @@ def fused_ep_moe(
13761376 hbm_block_spec = pl .BlockSpec (memory_space = pltpu .MemorySpace .HBM )
13771377 renorm_str = "-renorm_k" if renormalize_topk_logits else ""
13781378 scope_name = f"fused-moe-k_{ top_k } { renorm_str } -bt_{ bt } _{ btc } -bf_{ bf } _{ bfc } -bd1_{ bd1 } _{ bd1c } -bd2_{ bd2 } _{ bd2c } "
1379- fused_moe = jax .named_scope (scope_name )(
1380- pl .pallas_call (
1381- functools .partial (
1382- _fused_ep_moe_kernel ,
1383- top_k = top_k ,
1384- renormalize_topk_logits = renormalize_topk_logits ,
1385- ep_axis_name = ep_axis_name ,
1386- act_fn = act_fn ,
1387- subc_quant_wsz = subc_quant_wsz ,
1388- bt = bt ,
1389- bf = bf ,
1390- bd1 = bd1 ,
1391- bd2 = bd2 ,
1392- btc = btc ,
1393- bfc = bfc ,
1394- bd1c = bd1c ,
1395- bd2c = bd2c ,
1396- ),
1397- out_shape = jax .ShapeDtypeStruct ((local_num_tokens , hidden_size ),
1398- t_dtype ),
1399- grid_spec = pltpu .PrefetchScalarGridSpec (
1400- num_scalar_prefetch = 0 ,
1401- in_specs = [
1402- hbm_block_spec , # tokens_hbm
1403- hbm_block_spec , # w1_hbm
1404- hbm_block_spec , # w2_hbm
1405- None
1406- if w1_scale is None else hbm_block_spec , # w1_scale_hbm
1407- None
1408- if w2_scale is None else hbm_block_spec , # w2_scale_hbm
1409- None if b1 is None else hbm_block_spec , # b1_hbm
1410- None if b2 is None else hbm_block_spec , # b2_hbm
1411- hbm_block_spec , # gating_output_hbm
1412- hbm_block_spec , # a2a_g_hbm
1413- ],
1414- out_specs = pl .BlockSpec (memory_space = pltpu .MemorySpace .HBM ),
1415- scratch_shapes = ([
1416- # t2e_routing_x2_smem
1417- pltpu .SMEM ((2 , bt , padded_top_k ), jnp .int32 ),
1418- # d2e_count_x2_smem
1419- pltpu .SMEM ((2 , num_devices , 1 , padded_num_experts ),
1420- jnp .int32 ),
1421- # expert_offsets_x2_smem
1422- pltpu .SMEM ((2 , 2 , padded_num_experts ), jnp .int32 ),
1423- # expert_starts_x2_smem
1424- pltpu .SMEM ((2 , 1 , padded_num_experts ), jnp .int32 ),
1425- # expert_sizes_x2_smem
1426- pltpu .SMEM ((2 , 1 , padded_num_experts ), jnp .int32 ),
1427- # a2a_s_sends_x2_smem
1428- pltpu .SMEM ((2 , ), jnp .int32 ),
1429- # a2a_s_x2_vmem
1430- pltpu .VMEM (
1431- (
1432- 2 ,
1433- bt * num_devices ,
1434- t_packing ,
1435- hidden_size // t_packing ,
1436- ),
1437- t_dtype ,
1379+ fused_moe = pl .pallas_call (
1380+ functools .partial (
1381+ _fused_ep_moe_kernel ,
1382+ top_k = top_k ,
1383+ renormalize_topk_logits = renormalize_topk_logits ,
1384+ ep_axis_name = ep_axis_name ,
1385+ act_fn = act_fn ,
1386+ subc_quant_wsz = subc_quant_wsz ,
1387+ bt = bt ,
1388+ bf = bf ,
1389+ bd1 = bd1 ,
1390+ bd2 = bd2 ,
1391+ btc = btc ,
1392+ bfc = bfc ,
1393+ bd1c = bd1c ,
1394+ bd2c = bd2c ,
1395+ ),
1396+ out_shape = jax .ShapeDtypeStruct ((local_num_tokens , hidden_size ),
1397+ t_dtype ),
1398+ grid_spec = pltpu .PrefetchScalarGridSpec (
1399+ num_scalar_prefetch = 0 ,
1400+ in_specs = [
1401+ hbm_block_spec , # tokens_hbm
1402+ hbm_block_spec , # w1_hbm
1403+ hbm_block_spec , # w2_hbm
1404+ None if w1_scale is None else hbm_block_spec , # w1_scale_hbm
1405+ None if w2_scale is None else hbm_block_spec , # w2_scale_hbm
1406+ None if b1 is None else hbm_block_spec , # b1_hbm
1407+ None if b2 is None else hbm_block_spec , # b2_hbm
1408+ hbm_block_spec , # gating_output_hbm
1409+ hbm_block_spec , # a2a_g_hbm
1410+ ],
1411+ out_specs = pl .BlockSpec (memory_space = pltpu .MemorySpace .HBM ),
1412+ scratch_shapes = ([
1413+ # t2e_routing_x2_smem
1414+ pltpu .SMEM ((2 , bt , padded_top_k ), jnp .int32 ),
1415+ # d2e_count_x2_smem
1416+ pltpu .SMEM ((2 , num_devices , 1 , padded_num_experts ), jnp .int32 ),
1417+ # expert_offsets_x2_smem
1418+ pltpu .SMEM ((2 , 2 , padded_num_experts ), jnp .int32 ),
1419+ # expert_starts_x2_smem
1420+ pltpu .SMEM ((2 , 1 , padded_num_experts ), jnp .int32 ),
1421+ # expert_sizes_x2_smem
1422+ pltpu .SMEM ((2 , 1 , padded_num_experts ), jnp .int32 ),
1423+ # a2a_s_sends_x2_smem
1424+ pltpu .SMEM ((2 , ), jnp .int32 ),
1425+ # a2a_s_x2_vmem
1426+ pltpu .VMEM (
1427+ (
1428+ 2 ,
1429+ bt * num_devices ,
1430+ t_packing ,
1431+ hidden_size // t_packing ,
14381432 ),
1439- # a2a_s_acc_x2_vmem
1440- pltpu . VMEM (
1441- (
1442- 2 ,
1443- bt * num_devices ,
1444- t_packing ,
1445- hidden_size // t_packing ,
1446- ) ,
1447- t_dtype ,
1433+ t_dtype ,
1434+ ),
1435+ # a2a_s_acc_x2_vmem
1436+ pltpu . VMEM (
1437+ (
1438+ 2 ,
1439+ bt * num_devices ,
1440+ t_packing ,
1441+ hidden_size // t_packing ,
14481442 ),
1449- # a2a_g_acc_vmem
1450- pltpu .VMEM (
1451- (top_k , bt , t_packing , hidden_size // t_packing ),
1452- t_dtype ),
1453- # b_gating_x2_vmem
1454- pltpu .VMEM ((2 , bt , padded_num_experts ), t_dtype ),
1455- # b_output_x2_vmem
1456- pltpu .VMEM ((2 , bt , hidden_size ), t_dtype ),
1457- # b_w1_x2_vmem
1458- pltpu .VMEM ((2 , t_packing , bd1 // t_packing , bf ), w1 .dtype ),
1459- # b_w3_x2_vmem
1460- pltpu .VMEM ((2 , t_packing , bd1 // t_packing , bf ), w1 .dtype ),
1461- # b_w2_x2_vmem
1462- pltpu .VMEM ((2 , t_packing , bf , bd2 // t_packing ), w2 .dtype ),
1463- # b_w1_scale_x2_vmem
1464- (None if w1_scale is None else pltpu .VMEM (
1465- (
1466- 2 ,
1467- t_packing ,
1468- bd1 // t_packing // subc_quant_wsz ,
1469- 1 ,
1470- bf ,
1471- ),
1472- jnp .float32 ,
1473- )),
1474- # b_w3_scale_x2_vmem
1475- (None if w1_scale is None else pltpu .VMEM (
1476- (
1477- 2 ,
1478- t_packing ,
1479- bd1 // t_packing // subc_quant_wsz ,
1480- 1 ,
1481- bf ,
1482- ),
1483- jnp .float32 ,
1484- )),
1485- # b_w2_scale_x2_vmem
1486- (None if w2_scale is None else pltpu .VMEM (
1487- (
1488- 2 ,
1489- t_packing ,
1490- bf // subc_quant_wsz ,
1491- 1 ,
1492- bd2 // t_packing ,
1493- ),
1494- jnp .float32 ,
1495- )),
1496- # b_b1_x2_vmem
1497- (None if b1 is None else pltpu .VMEM (
1498- (
1499- 2 ,
1500- 1 ,
1501- bf ,
1502- ),
1503- jnp .float32 ,
1504- )),
1505- # b_b3_x2_vmem
1506- (None if b1 is None else pltpu .VMEM (
1507- (
1508- 2 ,
1509- 1 ,
1510- bf ,
1511- ),
1512- jnp .float32 ,
1513- )),
1514- # b_b2_x2_vmem
1515- (None if b2 is None else pltpu .VMEM (
1516- (
1517- 2 ,
1518- t_packing ,
1519- 1 ,
1520- bd2 // t_packing ,
1521- ),
1522- jnp .float32 ,
1523- )),
1524- # b_acc_vmem
1525- pltpu .VMEM ((bt * num_devices , 1 , bf * 2 ), jnp .float32 ),
1526- # local_sems
1527- pltpu .SemaphoreType .DMA ((2 , 5 )),
1528- # send_sems
1529- pltpu .SemaphoreType .DMA ((2 , )),
1530- # recv_sems
1531- pltpu .SemaphoreType .DMA ((2 , )),
1532- # a2a_gather_sem
1533- pltpu .SemaphoreType .DMA ,
1534- # a2a_acc_sem
1535- pltpu .SemaphoreType .DMA ,
1536- ]),
1537- ),
1538- compiler_params = pltpu .CompilerParams (
1539- collective_id = 0 ,
1540- vmem_limit_bytes = 100 * 1024 * 1024 ,
1541- ),
1542- name = scope_name ,
1543- ))
1443+ t_dtype ,
1444+ ),
1445+ # a2a_g_acc_vmem
1446+ pltpu .VMEM ((top_k , bt , t_packing , hidden_size // t_packing ),
1447+ t_dtype ),
1448+ # b_gating_x2_vmem
1449+ pltpu .VMEM ((2 , bt , padded_num_experts ), t_dtype ),
1450+ # b_output_x2_vmem
1451+ pltpu .VMEM ((2 , bt , hidden_size ), t_dtype ),
1452+ # b_w1_x2_vmem
1453+ pltpu .VMEM ((2 , t_packing , bd1 // t_packing , bf ), w1 .dtype ),
1454+ # b_w3_x2_vmem
1455+ pltpu .VMEM ((2 , t_packing , bd1 // t_packing , bf ), w1 .dtype ),
1456+ # b_w2_x2_vmem
1457+ pltpu .VMEM ((2 , t_packing , bf , bd2 // t_packing ), w2 .dtype ),
1458+ # b_w1_scale_x2_vmem
1459+ (None if w1_scale is None else pltpu .VMEM (
1460+ (
1461+ 2 ,
1462+ t_packing ,
1463+ bd1 // t_packing // subc_quant_wsz ,
1464+ 1 ,
1465+ bf ,
1466+ ),
1467+ jnp .float32 ,
1468+ )),
1469+ # b_w3_scale_x2_vmem
1470+ (None if w1_scale is None else pltpu .VMEM (
1471+ (
1472+ 2 ,
1473+ t_packing ,
1474+ bd1 // t_packing // subc_quant_wsz ,
1475+ 1 ,
1476+ bf ,
1477+ ),
1478+ jnp .float32 ,
1479+ )),
1480+ # b_w2_scale_x2_vmem
1481+ (None if w2_scale is None else pltpu .VMEM (
1482+ (
1483+ 2 ,
1484+ t_packing ,
1485+ bf // subc_quant_wsz ,
1486+ 1 ,
1487+ bd2 // t_packing ,
1488+ ),
1489+ jnp .float32 ,
1490+ )),
1491+ # b_b1_x2_vmem
1492+ (None if b1 is None else pltpu .VMEM (
1493+ (
1494+ 2 ,
1495+ 1 ,
1496+ bf ,
1497+ ),
1498+ jnp .float32 ,
1499+ )),
1500+ # b_b3_x2_vmem
1501+ (None if b1 is None else pltpu .VMEM (
1502+ (
1503+ 2 ,
1504+ 1 ,
1505+ bf ,
1506+ ),
1507+ jnp .float32 ,
1508+ )),
1509+ # b_b2_x2_vmem
1510+ (None if b2 is None else pltpu .VMEM (
1511+ (
1512+ 2 ,
1513+ t_packing ,
1514+ 1 ,
1515+ bd2 // t_packing ,
1516+ ),
1517+ jnp .float32 ,
1518+ )),
1519+ # b_acc_vmem
1520+ pltpu .VMEM ((bt * num_devices , 1 , bf * 2 ), jnp .float32 ),
1521+ # local_sems
1522+ pltpu .SemaphoreType .DMA ((2 , 5 )),
1523+ # send_sems
1524+ pltpu .SemaphoreType .DMA ((2 , )),
1525+ # recv_sems
1526+ pltpu .SemaphoreType .DMA ((2 , )),
1527+ # a2a_gather_sem
1528+ pltpu .SemaphoreType .DMA ,
1529+ # a2a_acc_sem
1530+ pltpu .SemaphoreType .DMA ,
1531+ ]),
1532+ ),
1533+ compiler_params = pltpu .CompilerParams (
1534+ collective_id = 0 ,
1535+ vmem_limit_bytes = 100 * 1024 * 1024 ,
1536+ ),
1537+ name = scope_name ,
1538+ )
15441539
15451540 @jax .jit
15461541 @jax .shard_map (
0 commit comments