Merged
Commits (22):
- 52a69b7: bug fixed, test added (vthumbe1503, Nov 26, 2025)
- 5547e4c: fix contigous (vthumbe1503, Nov 26, 2025)
- 681ad87: Merge branch 'NVIDIA:main' into fix_split_tensor_bug (vthumbe1503, Nov 26, 2025)
- 0b8ccbb: revert unecessary change (vthumbe1503, Nov 26, 2025)
- 8adcbb3: revert another change (vthumbe1503, Nov 26, 2025)
- 18e85ce: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 26, 2025)
- 0ce6c7e: address review comments (vthumbe1503, Nov 26, 2025)
- d582d2b: Merge branch 'main' into fix_split_tensor_bug (vthumbe1503, Nov 26, 2025)
- bbc669a: Update transformer_engine/pytorch/tensor/mxfp8_tensor.py (vthumbe1503, Nov 26, 2025)
- 01b1cc2: Merge branch 'main' into fix_split_tensor_bug (vthumbe1503, Nov 30, 2025)
- fcb4c9f: address review comments (vthumbe1503, Dec 4, 2025)
- 6a61b08: missed adding renamed file (vthumbe1503, Dec 4, 2025)
- 2fd860b: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 4, 2025)
- 7fc1c6f: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 4, 2025)
- 107b461: Merge branch 'main' into fix_split_tensor_bug (vthumbe1503, Dec 4, 2025)
- afbed3d: Merge branch 'main' into fix_split_tensor_bug (vthumbe1503, Dec 7, 2025)
- f84e3f6: fix minor issue (vthumbe1503, Dec 7, 2025)
- 50b7c74: fix ci issue (vthumbe1503, Dec 8, 2025)
- 492156c: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 8, 2025)
- 5ccda23: fix the test for bfloat16 (vthumbe1503, Dec 8, 2025)
- 67b1a93: Merge branch 'fix_split_tensor_bug' of github.com:vthumbe1503/Transfo… (vthumbe1503, Dec 8, 2025)
- 4cd7895: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 8, 2025)
72 changes: 72 additions & 0 deletions tests/pytorch/test_numerics.py
@@ -2976,3 +2976,75 @@ def _run_module(m, inp):
    out = _run_module(g2, b)

    assert_allclose(out, outT, 1e-7)


@pytest.mark.parametrize(
    "quantization_type",
    [
        "fp8",
        "mxfp8",
    ],
)
@pytest.mark.parametrize(
    "shape,chunks,dim",
    [
        ((64, 128), 2, 0),  # Split along first dimension, needs padding for mxfp8
        ((64, 128), 2, 1),  # Split along second dimension, goes down dequantization path for mxfp8
    ],
)
def test_fp8_split_functionality(quantization_type, shape, chunks, dim):
    """Test torch.chunk on FP8 and MXFP8 tensors and verify correctness via dequantization."""
    if quantization_type == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization_type == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    device = "cuda"
    dtype = torch.bfloat16

    # Create reference tensor
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    ref_tensor = torch.randn(shape, device=device, dtype=dtype)

    # Quantize the tensor
    if quantization_type == "fp8":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        quantized_tensor = quantizer(ref_tensor)
    elif quantization_type == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        quantized_tensor = quantizer(ref_tensor)

    # Apply torch.chunk on the quantized tensor
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make the chunks contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
    if dim == 0 or quantization_type == "fp8":
        # Dequantize the chunked results
        chunked_dequantized = [chunk.dequantize() for chunk in outs]
    else:
        # For MXFP8, splitting along the second dimension currently falls back to the
        # dequantization path, so the chunks are already high-precision tensors.
        chunked_dequantized = outs

    # Reference: chunk the dequantized tensor directly
    ref_dequantized = quantized_tensor.dequantize()
    ref_chunked = torch.chunk(ref_dequantized, chunks, dim=dim)

    # Compare results
    assert len(chunked_dequantized) == len(
        ref_chunked
    ), f"Number of chunks mismatch: {len(chunked_dequantized)} vs {len(ref_chunked)}"

    for i, (chunk_deq, ref_chunk) in enumerate(zip(chunked_dequantized, ref_chunked)):
        assert (
            chunk_deq.shape == ref_chunk.shape
        ), f"Chunk {i} shape mismatch: {chunk_deq.shape} vs {ref_chunk.shape}"
        torch.testing.assert_close(
            chunk_deq,
            ref_chunk,
        )
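For context, here is a minimal standalone sketch (not Transformer Engine code) of the mechanism this test exercises: torch.chunk on a tensor subclass is routed through `__torch_dispatch__`, where the subclass sees the underlying aten op (typically `aten.split.Tensor`) and can split its packed data and scale buffers itself. The `LoggingTensor` class and its unwrap helper below are illustrative names only.

```python
import torch
from torch.utils._pytree import tree_map


class LoggingTensor(torch.Tensor):
    """Toy subclass that prints every aten op it intercepts (illustration only)."""

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"dispatched: {func}")  # torch.chunk typically arrives as aten.split.Tensor
        kwargs = kwargs or {}

        # Unwrap back to plain tensors so the call below does not re-enter this method.
        def unwrap(t):
            return t.as_subclass(torch.Tensor) if isinstance(t, LoggingTensor) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))


x = torch.randn(4, 4).as_subclass(LoggingTensor)
chunks = torch.chunk(x, 2, dim=0)  # prints the intercepted op(s); MXFP8Tensor hooks this same path
```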
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
@@ -560,7 +560,11 @@ def contiguous(
             memory_format=memory_format
         ):
             return self
-        return Float8Tensor.make_like(tensor=self, data=self._data.contiguous())
+        return Float8Tensor.make_like(
+            tensor=self,
+            data=self._data.contiguous(),
+            data_transpose=self._transpose.contiguous() if self._transpose is not None else None,
+        )

         # raise ValueError("Float8Tensor does not support different memory formats!")

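The updated return makes both buffers contiguous: the packed FP8 data and, when present, the cached transpose. A minimal standalone sketch of that invariant, using a toy class and names rather than the Float8Tensor API:

```python
from typing import Optional

import torch


class PackedPair:
    """Toy stand-in for a tensor that carries a primary buffer plus a cached transpose."""

    def __init__(self, data: torch.Tensor, data_t: Optional[torch.Tensor] = None):
        self.data = data      # stand-in for Float8Tensor._data
        self.data_t = data_t  # stand-in for Float8Tensor._transpose

    def contiguous(self) -> "PackedPair":
        # Both buffers are made contiguous together; forwarding only `data` would
        # leave the cached transpose behind or in a non-contiguous layout.
        return PackedPair(
            self.data.contiguous(),
            self.data_t.contiguous() if self.data_t is not None else None,
        )
```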
13 changes: 8 additions & 5 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -434,13 +434,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                     if scale_inv is not None
                     else None
                 )
+                scale_inv_out = list(scale_inv_out) if scale_inv_out is not None else None
                 # Pad scale_inv_out to be a multiple of pad_multiple
                 if scale_inv_out is not None:
-                    current_shape = scale_inv_out.shape
-                    pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
-                    if pad_dim0 > 0:
-                        scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
-
+                    for idx, split_scale_inv_out in enumerate(scale_inv_out):
+                        current_shape = split_scale_inv_out.shape
+                        pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
+                        if pad_dim0 > 0:
+                            scale_inv_out[idx] = torch.nn.functional.pad(
+                                split_scale_inv_out, (0, 0, 0, pad_dim0)
+                            )
                 out_data.append(scale_inv_out)
             return [
                 MXFP8Tensor(
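The new loop pads each per-split scale-inverse tensor independently instead of padding the pre-split tensor once. A standalone sketch of that row-padding step; the helper name is hypothetical, and `pad_multiple` comes from the surrounding code and is only assumed to be 128 here:

```python
import torch


def pad_split_scale_invs(split_scale_invs, pad_multiple=128):
    """Pad the row count of each per-chunk scale-inverse tensor up to a multiple of
    ``pad_multiple`` (illustrative helper, not the Transformer Engine implementation)."""
    padded = list(split_scale_invs)
    for idx, scale_inv in enumerate(padded):
        pad_dim0 = (pad_multiple - scale_inv.shape[0] % pad_multiple) % pad_multiple
        if pad_dim0 > 0:
            # For a 2-D input, torch.nn.functional.pad takes (left, right, top, bottom),
            # so (0, 0, 0, pad_dim0) appends pad_dim0 rows of zeros.
            padded[idx] = torch.nn.functional.pad(scale_inv, (0, 0, 0, pad_dim0))
    return padded


# Example: three even splits of a 96-row scale_inv tensor, each padded back up to 128 rows.
splits = torch.chunk(torch.ones(96, 4, dtype=torch.uint8), 3, dim=0)
print([s.shape for s in pad_split_scale_invs(splits)])
# [torch.Size([128, 4]), torch.Size([128, 4]), torch.Size([128, 4])]
```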