From d04832126bdbfae8a700fa13e3c3ff0a9ce2b1f4 Mon Sep 17 00:00:00 2001
From: "Sun, Diwei"
Date: Tue, 19 Aug 2025 07:35:07 +0000
Subject: [PATCH 1/4] xpu ut enabling: moe_training

---
 test/prototype/moe_training/test_kernels.py  | 10 +++++--
 .../moe_training/test_scaled_grouped_mm.py   | 30 +++++++++++--------
 2 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index a10f41e696..d5ae767af5 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -26,15 +26,19 @@
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import (
+    skip_if_rocm,
+)
+from torchao.utils import auto_detect_device

+_DEVICE = auto_detect_device()

 @skip_if_rocm("ROCm enablement in progress")
 @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
 def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     # tests case where rowwise scales are computed for multiple distinct subtensors,
     # with end boundary of each group is determine by their end column indexes (offsets).
-    device = "cuda"
+    device = _DEVICE
     m, k, n_groups = 256, 256, 4
     x = torch.randn(m, k * n_groups, device=device)
     colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device)
@@ -62,7 +66,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
 def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
     # tests case where colwise scales are computed for multiple distinct subtensors,
     # with end boundary of each group is determine by their end row indexes (offsets).
-    device = "cuda"
+    device = _DEVICE
     m, k, n_groups = 256, 256, 4
     x = torch.randn(m * n_groups, k, device=device).t().contiguous().t()
     rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device)
diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index 4b76b29a27..f0e43ee600 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -14,8 +14,6 @@
 # triton won't be available on CPU builds and torch < 2.5
 if not (
     TORCH_VERSION_AT_LEAST_2_7
-    and torch.cuda.is_available()
-    and torch.cuda.get_device_capability()[0] >= 9
 ):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -39,13 +37,16 @@
     generate_jagged_offs,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
+from torchao.utils import auto_detect_device

+_DEVICE = auto_detect_device()

 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU not supported")
 def test_valid_scaled_grouped_mm_2d_3d():
     out_dtype = torch.bfloat16
-    device = "cuda"
+    device = _DEVICE
     m, n, k, n_groups = 16, 32, 16, 4
     a = torch.randn(
         m * n_groups,
@@ -61,7 +62,7 @@
         device=device,
         dtype=torch.bfloat16,
     )
-    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32)

     # b must be transposed and in column major format.
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
@@ -109,7 +110,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
     if n % 16 == 0 and k % 16 == 0:
         return
     out_dtype = torch.bfloat16
-    device = "cuda"
+    device = _DEVICE
     n_groups = 4
     a = torch.randn(
         m * n_groups,
@@ -131,7 +132,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
     b_t = b.transpose(-2, -1)
     b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1)

-    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32)

     # Compute output.
     with pytest.raises(AssertionError):
@@ -226,11 +227,12 @@ def compute_reference_forward(


 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE)
     offs = generate_jagged_offs(num_experts, M)
     x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()

@@ -257,15 +259,16 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):


 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU not supported")
 @pytest.mark.parametrize("M", (1024, 4096))
 @pytest.mark.parametrize("N", (1024, 4096))
 @pytest.mark.parametrize("num_experts", (8, 16))
 def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
     # Simluate 2d-2d grouped gemm grad_weight = grad_output_t @ x
     block_size = 32
-    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE)
     grad_out_t = grad_out.t().contiguous()
-    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE)
     offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
     x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()

@@ -305,6 +308,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):


 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
@@ -313,9 +317,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
     )

     block_size = 32
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True)
     w_t = torch.randn(
-        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True
     )
     offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
     x_ref, w_t_ref, offs_ref = (

From 7edb7b18c9f67a500397087bf70b5b3fa4a8b18d Mon Sep 17 00:00:00 2001
From: "Sun, Diwei"
Date: Tue, 19 Aug 2025 07:54:05 +0000
Subject: [PATCH 2/4] xpu ut enabling: mx_formats/test_kernels

---
 test/prototype/mx_formats/test_kernels.py | 60 +++++++++++------------
 1 file changed, 28 insertions(+), 32 deletions(-)

diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 6b0aab129c..58f286ef05 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -45,11 +45,18 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    get_available_devices,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )
+from torchao.testing.utils import skip_if_xpu
+
+from torchao.utils import get_available_devices
+
+
+_DEVICES = get_available_devices()

 torch.manual_seed(0)

 if not TORCH_VERSION_AT_LEAST_2_8:
@@ -327,9 +334,9 @@ def test_fp4_pack_unpack():
     assert torch.all(orig_vals_dq == orig_vals)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
+@skip_if_xpu("XPU not supported")
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -337,9 +344,9 @@ def test_fp4_triton_unscaled_cast():
     assert torch.all(torch.eq(f32_ref, f32_triton))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
+@skip_if_xpu("XPU not supported")
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
@@ -357,7 +364,7 @@ def test_fp4_triton_scaled_cast():
     f32_triton = mxtensor_triton.to_dtype(torch.float)
     assert torch.all(torch.eq(f32_ref, f32_triton))

-
+@skip_if_xpu("XPU not supported")
 @pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2))
 def test_fp6_values(dtype_name):
     """
@@ -403,18 +410,8 @@ def test_fp6_values(dtype_name):
     torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)


-@pytest.mark.parametrize(
-    "device",
-    [
-        "cpu",
-        pytest.param(
-            "cuda",
-            marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
-            ),
-        ),
-    ],
-)
+@skip_if_xpu("XPU not supported")
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.parametrize(
     "f32_val,f6_e3m2_enc",
     [
@@ -433,12 +430,11 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
     assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_xpu("XPU not supported")
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-def test_fp6_e2m3_pack_unpack():
-    orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
-        "cuda"
-    )
+def test_fp6_e2m3_pack_unpack(device):
+    orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(device)
     orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
     assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
@@ -448,12 +444,11 @@ def test_fp6_e2m3_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_xpu("XPU not supported")
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-def test_fp6_e3m2_pack_unpack():
-    orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
-        "cuda"
-    )
+def test_fp6_e3m2_pack_unpack(device):
+    orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(device)
     orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
     assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
@@ -471,14 +466,15 @@
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
 def test_triton_mxfp8_dim1_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_xpu("XPU not supported")
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.parametrize(
     "shape",
     [
@@ -492,8 +488,8 @@ def test_triton_mxfp8_dim1_randn(M, K):
         (128, 1),
     ],
 )
-def test_rearrange(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+def test_rearrange(device, shape):
+    scales = torch.randint(256, size=shape, device=device, dtype=torch.uint8)
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
@@ -519,7 +515,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):

     # Use disinct incrementing values from 0 to M*K-1 to make debugging easier.
     x = (
-        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        torch.arange(0, M * K, dtype=input_dtype, device=device)
         .reshape(M, K)
         .contiguous()
     )
@@ -557,7 +553,7 @@ def test_cuda_mx_dim0_not_supported():
     M, K = 64, 64
     block_size = 32
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=device)
         .reshape(M, K)
         .contiguous()
     )
@@ -580,7 +576,7 @@ def test_cuda_mx_dim1_invalid_block_size():

     M, K = 64, 64
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=device)
         .reshape(M, K)
         .contiguous()
     )

From 91a366c7ac7e29b390bb204b561d20ad5a1ad3f3 Mon Sep 17 00:00:00 2001
From: "Sun, Diwei"
Date: Tue, 19 Aug 2025 07:54:18 +0000
Subject: [PATCH 3/4] xpu ut enabling: test_mx_linear

---
 test/prototype/mx_formats/test_mx_linear.py | 39 +++++++++++----------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index edce5cc7e7..08e8509b13 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -31,6 +31,12 @@
     is_sm_at_least_100,
 )

+from torchao.testing.utils import skip_if_xpu
+
+from torchao.utils import auto_detect_device
+
+_DEVICE = auto_detect_device()
+
 torch.manual_seed(2)

 if not TORCH_VERSION_AT_LEAST_2_8:
@@ -67,7 +73,6 @@ def run_around_tests():
     )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
@@ -121,7 +126,7 @@ def test_linear_eager_vs_hp(
         grad_shape[-1] = 256

     m = nn.Sequential(
-        nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16),
+        nn.Linear(256, 256, bias=bias, device=_DEVICE, dtype=torch.bfloat16),
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
@@ -135,10 +140,10 @@ def test_linear_eager_vs_hp(
     quantize_(m_mx, config)

     x_ref = torch.randn(
-        *input_shape, device="cuda", dtype=torch.bfloat16
+        *input_shape, device=_DEVICE, dtype=torch.bfloat16
     ).requires_grad_()
     x = copy.deepcopy(x_ref)
-    g = torch.randn(*grad_shape, device="cuda")
+    g = torch.randn(*grad_shape, device=_DEVICE)

     y_ref = m(x_ref)
     y_mx = m_mx(x)
@@ -162,9 +167,9 @@ def test_linear_eager_vs_hp(
     assert x_g_sqnr >= 8.0


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_xpu("XPU enablement in progress")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+    torch.cuda.is_available() and not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
 )
 @pytest.mark.parametrize(
     "recipe_name",
@@ -177,11 +182,11 @@ def test_linear_eager_vs_hp(
 def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     M, K, N = mkn

-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE).requires_grad_()
     x_copy = copy.deepcopy(x)
-    g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
+    g = torch.randn(M, N, device=_DEVICE, dtype=torch.bfloat16)
     m_emulated = nn.Sequential(
-        nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16),
+        nn.Linear(K, N, bias=False, device=_DEVICE, dtype=torch.bfloat16),
     )
     m_real = copy.deepcopy(m_emulated)

@@ -211,26 +216,24 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):


 # TODO(future): enable compile support
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
     input_shape = (16, 4)
     grad_shape = (16, 8)
     elem_dtype = torch.float8_e4m3fn
     m = nn.Sequential(
-        nn.Linear(4, 8, bias=True, device="cuda"),
-        nn.Linear(8, 8, bias=True, device="cuda"),
+        nn.Linear(4, 8, bias=True, device=_DEVICE),
+        nn.Linear(8, 8, bias=True, device=_DEVICE),
     )
     config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
     quantize_(m, config=config)

-    x = torch.randn(*input_shape, device="cuda").requires_grad_()
-    g = torch.randn(*grad_shape, device="cuda")
+    x = torch.randn(*input_shape, device=_DEVICE).requires_grad_()
+    g = torch.randn(*grad_shape, device=_DEVICE)

     y = torch.utils.checkpoint.checkpoint(m, x, use_reentrant=False)
     y.backward(g)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize(
     "recipe_name",
@@ -311,7 +314,7 @@ def test_linear_compile(
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
-        nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
+        nn.Linear(K, N, bias=bias, device=_DEVICE, dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
     config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
@@ -321,9 +324,9 @@ def test_linear_compile(
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

-    x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_()
+    x_ref = torch.randn(*input_shape, device=_DEVICE, dtype=hp_dtype).requires_grad_()
     x = copy.deepcopy(x_ref)
-    g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype)
+    g = torch.randn(*grad_shape, device=_DEVICE, dtype=hp_dtype)

     y_ref = m_mx(x_ref)
     y = m_mx_c(x)

From 61e8efd719d999f67404dff659b3bc7c2363f62e Mon Sep 17 00:00:00 2001
From: "Sun, Diwei"
Date: Tue, 19 Aug 2025 07:55:07 +0000
Subject: [PATCH 4/4] xpu ut enabling: test_mx_mm

---
 test/prototype/mx_formats/test_mx_mm.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
index 46380cfb55..636ddbf244 100644
--- a/test/prototype/mx_formats/test_mx_mm.py
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -13,17 +13,22 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    auto_detect_device,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_100,
 )

+from torchao.testing.utils import skip_if_xpu
+
+_DEVICE = auto_detect_device()
+
 if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


 def run_matrix_test(M: int, K: int, N: int, format) -> float:
     dtype = torch.bfloat16
-    device = torch.device("cuda")
+    device = torch.device(_DEVICE)
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)

@@ -57,9 +62,9 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     return compute_error(out_hp, out).item()


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_xpu("XPU enablement in progress")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+    torch.cuda.is_available() and not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
 )
 @pytest.mark.parametrize(
     "size",