12 changes: 0 additions & 12 deletions test/prototype/mx_formats/test_kernels.py
@@ -35,7 +35,6 @@
get_bits,
pack_uint4,
pack_uint6,
triton_f4_to_bf16,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
triton_to_mxfp8_dim1,
@@ -327,17 +326,6 @@ def test_fp4_pack_unpack():
assert torch.all(orig_vals_dq == orig_vals)


# TODO(future PR): fix or delete this test
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(is_sm_at_least_89(), reason="broken on CUDA capability 8.9+")
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float)
assert torch.all(torch.eq(f32_ref, f32_triton))


# TODO(future PR): fix or delete this test
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
102 changes: 0 additions & 102 deletions torchao/prototype/mx_formats/kernels.py
@@ -196,55 +196,6 @@ def _fp4_packed_to_bf16(
output = output.to(tl.bfloat16)
return output

@triton.jit
def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
sign_mask_f4: tl.constexpr,
mantissa_mask_f4: tl.constexpr,
mbits_f4_e2m1: tl.constexpr,
ebits_f4_e2m1: tl.constexpr,
f4_e2m1_exp_bias: tl.constexpr,
mbits_f32: tl.constexpr,
ebits_f32: tl.constexpr,
f32_exp_bias: tl.constexpr,
zero_bits_f32: tl.constexpr,
zero_point_five_bits_f32: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
n_elements_out = n_elements_in * 2
BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2

block_start_in = pid * BLOCK_SIZE_IN
offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)

mask_in = offsets_in < n_elements_in

# packed uint8
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
output = _fp4_packed_to_bf16(
x_packed,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
)

# set up output offsets
block_start_out = pid * BLOCK_SIZE_OUT
offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)
mask_out = offsets_out < n_elements_out

tl.store(output_ptr + offsets_out, output, mask=mask_out)

@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_IN": 128}),
@@ -624,24 +575,6 @@ def triton_pack_uint6_kernel(

else:

def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
BLOCK_SIZE_IN,
):
raise AssertionError("unsupported without triton")

def triton_f4_to_scaled_bf16_kernel(
x_ptr,
s_ptr,
@@ -705,41 +638,6 @@ def triton_pack_uint6_kernel(
raise AssertionError("unsupported without triton")


def triton_f4_to_bf16(x: torch.Tensor):
"""
Input: a tensor of packed fp4 values
Output: a tensor of bfloat16 values

Note: this function is only used in testing, so we can test
the numerical correctness of the cast without the scaling.
"""
new_shape = (*x.shape[:-1], x.shape[-1] * 2)
output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
assert x.is_contiguous()
assert x.is_cuda and output.is_cuda
n_elements_in = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
) # noqa: E731,E501
triton_f4_to_bf16_kernel[grid](
x,
output,
n_elements_in,
sign_mask_f4=SIGN_MASK_F4,
mantissa_mask_f4=MANTISSA_MASK_F4,
mbits_f4_e2m1=MBITS_F4_E2M1,
ebits_f4_e2m1=EBITS_F4_E2M1,
f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
mbits_f32=MBITS_F32,
ebits_f32=EBITS_F32,
f32_exp_bias=F32_EXP_BIAS,
zero_bits_f32=ZERO_BITS_F32,
zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
BLOCK_SIZE_IN=512,
)
return output


def triton_f4_to_scaled_bf16(
x: torch.Tensor,
s_e8m0: torch.Tensor,
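Note: with triton_f4_to_bf16 and its kernel removed, the unscaled fp4 to bf16 cast that the deleted test_fp4_triton_unscaled_cast exercised can still be reproduced from the pure-PyTorch reference helpers the test compared against. A minimal sketch, assuming unpack_uint4 and f4_unpacked_to_f32 remain importable from torchao.prototype.mx_formats.kernels (that import path is an assumption, not shown in this diff):

import torch

# Assumed import path; adjust if these reference helpers live elsewhere.
from torchao.prototype.mx_formats.kernels import f4_unpacked_to_f32, unpack_uint4


def f4_packed_to_bf16_reference(x_packed: torch.Tensor) -> torch.Tensor:
    # Unpack two fp4 (e2m1) values from each uint8 byte, decode them to fp32
    # via the reference path the deleted test used, then downcast to bfloat16.
    return f4_unpacked_to_f32(unpack_uint4(x_packed)).to(torch.bfloat16)


# Usage: the same packed input the deleted test cast on the GPU.
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
bf16_vals = f4_packed_to_bf16_reference(packed_vals)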