diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index d04a67771d..e553946413 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -35,7 +35,6 @@
     get_bits,
     pack_uint4,
     pack_uint6,
-    triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
     triton_to_mxfp8_dim1,
@@ -327,17 +326,6 @@ def test_fp4_pack_unpack():
     assert torch.all(orig_vals_dq == orig_vals)
 
 
-# TODO(future PR): fix or delete this test
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+")
-def test_fp4_triton_unscaled_cast():
-    packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
-    f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
-    f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float)
-    assert torch.all(torch.eq(f32_ref, f32_triton))
-
-
 # TODO(future PR): fix or delete this test
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index cabb61276a..732af4df2a 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -196,55 +196,6 @@ def _fp4_packed_to_bf16(
         output = output.to(tl.bfloat16)
         return output
 
-    @triton.jit
-    def triton_f4_to_bf16_kernel(
-        x_ptr,
-        output_ptr,
-        n_elements_in,
-        sign_mask_f4: tl.constexpr,
-        mantissa_mask_f4: tl.constexpr,
-        mbits_f4_e2m1: tl.constexpr,
-        ebits_f4_e2m1: tl.constexpr,
-        f4_e2m1_exp_bias: tl.constexpr,
-        mbits_f32: tl.constexpr,
-        ebits_f32: tl.constexpr,
-        f32_exp_bias: tl.constexpr,
-        zero_bits_f32: tl.constexpr,
-        zero_point_five_bits_f32: tl.constexpr,
-        BLOCK_SIZE_IN: tl.constexpr,
-    ):
-        pid = tl.program_id(axis=0)
-        n_elements_out = n_elements_in * 2
-        BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2
-
-        block_start_in = pid * BLOCK_SIZE_IN
-        offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)
-
-        mask_in = offsets_in < n_elements_in
-
-        # packed uint8
-        x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(
-            x_packed,
-            sign_mask_f4,
-            mantissa_mask_f4,
-            mbits_f4_e2m1,
-            ebits_f4_e2m1,
-            f4_e2m1_exp_bias,
-            mbits_f32,
-            ebits_f32,
-            f32_exp_bias,
-            zero_bits_f32,
-            zero_point_five_bits_f32,
-        )
-
-        # set up output offsets
-        block_start_out = pid * BLOCK_SIZE_OUT
-        offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)
-        mask_out = offsets_out < n_elements_out
-
-        tl.store(output_ptr + offsets_out, output, mask=mask_out)
-
     @triton.autotune(
         configs=[
             triton.Config({"BLOCK_SIZE_IN": 128}),
@@ -624,24 +575,6 @@ def triton_pack_uint6_kernel(
 
 else:
 
-    def triton_f4_to_bf16_kernel(
-        x_ptr,
-        output_ptr,
-        n_elements_in,
-        sign_mask_f4,
-        mantissa_mask_f4,
-        mbits_f4_e2m1,
-        ebits_f4_e2m1,
-        f4_e2m1_exp_bias,
-        mbits_f32,
-        ebits_f32,
-        f32_exp_bias,
-        zero_bits_f32,
-        zero_point_five_bits_f32,
-        BLOCK_SIZE_IN,
-    ):
-        raise AssertionError("unsupported without triton")
-
     def triton_f4_to_scaled_bf16_kernel(
         x_ptr,
         s_ptr,
@@ -705,41 +638,6 @@ def triton_pack_uint6_kernel(
         raise AssertionError("unsupported without triton")
 
 
-def triton_f4_to_bf16(x: torch.Tensor):
-    """
-    Input: a tensor of packed fp4 values
-    Output: a tensor of bfloat16 values
-
-    Note: this function is only used in testing, so we can test
-    the numerical correctness of the cast without the scaling.
-    """
-    new_shape = (*x.shape[:-1], x.shape[-1] * 2)
-    output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
-    assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
-    n_elements_in = x.numel()
-    grid = lambda meta: (  # noqa: E731
-        triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
-    )  # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](
-        x,
-        output,
-        n_elements_in,
-        sign_mask_f4=SIGN_MASK_F4,
-        mantissa_mask_f4=MANTISSA_MASK_F4,
-        mbits_f4_e2m1=MBITS_F4_E2M1,
-        ebits_f4_e2m1=EBITS_F4_E2M1,
-        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
-        mbits_f32=MBITS_F32,
-        ebits_f32=EBITS_F32,
-        f32_exp_bias=F32_EXP_BIAS,
-        zero_bits_f32=ZERO_BITS_F32,
-        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
-        BLOCK_SIZE_IN=512,
-    )
-    return output
-
-
 def triton_f4_to_scaled_bf16(
     x: torch.Tensor,
     s_e8m0: torch.Tensor,
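
For reference, the pure-PyTorch decode path that the deleted test_fp4_triton_unscaled_cast used as its ground truth is untouched by this diff and can still be exercised on its own. A minimal sketch, assuming unpack_uint4 and f4_unpacked_to_f32 remain importable from torchao.prototype.mx_formats.kernels (the import path is an assumption, mirrored from the test file's import block):

import torch

# assumed import path, mirroring test/prototype/mx_formats/test_kernels.py
from torchao.prototype.mx_formats.kernels import (
    f4_unpacked_to_f32,
    unpack_uint4,
)

# same inputs as the deleted test: uint8 bytes, each packing two fp4 e2m1 values
packed_vals = torch.arange(0, 255, dtype=torch.uint8)
# unpack each byte into two uint8 values holding one fp4 value each,
# then decode to float32 without applying any mx scale
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
assert f32_ref.numel() == 2 * packed_vals.numel()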