
Conversation

@vthumbe1503 (Collaborator) commented Nov 26, 2025

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes #2422

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • For MXFP8, we were retrieving the shape from the list of split scale-inverse tensors rather than from the split scale-inverse tensors themselves. Fixed it now, and added a unit test for the same.
  • Changed the contiguous API for the float8 tensor to also handle the transpose for L40/Hopper. Also fixed an issue where requires_grad was not maintained on the tensor after calling contiguous on it.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Nov 26, 2025

Greptile Overview

Greptile Summary

This PR fixes a critical bug in MXFP8 tensor splitting that caused AttributeError: 'list' object has no attribute 'shape' during model checkpointing. The issue occurred because torch.split returns a sequence of tensors, but the code incorrectly tried to access .shape on that container directly instead of iterating over the individual split tensors.
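
As a rough illustration of the failure mode, here is a standalone sketch (not the actual Transformer Engine code) of the buggy pattern and the fix:

import torch

scale_inv = torch.ones(8, 4)
scale_inv_out = torch.split(scale_inv, 4)  # returns a tuple of tensors, not a tensor

# Buggy pattern: treating the container as if it were a tensor
# scale_inv_out.shape  ->  AttributeError: the container has no attribute 'shape'

# Fixed pattern: convert to a list and handle each split tensor individually
scale_inv_out = list(scale_inv_out)
for idx, split_scale_inv in enumerate(scale_inv_out):
    print(idx, split_scale_inv.shape)  # each element is a tensor with a shape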

Key Changes:

  • Fixed mxfp8_tensor.py line 437: converts scale_inv_out to list explicitly and iterates over each split tensor for padding (lines 440-446)
  • Modified float8_tensor.py dequantize() to call contiguous() first to handle transpose for L40/Hopper
  • Updated contiguous() to make both _data and _transpose contiguous (if present) and preserve requires_grad (see the sketch after this list)
  • Added comprehensive test coverage for torch.chunk on all quantized tensor types including MXFP8
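
For reference, the requires_grad behavior that the updated contiguous() is expected to match mirrors plain torch.Tensor semantics; a minimal sketch in plain PyTorch (no Float8Tensor internals):

import torch

x = torch.randn(4, 4, requires_grad=True).t()  # non-contiguous view of a leaf tensor
y = x.contiguous()                             # returns a contiguous copy
assert y.is_contiguous()
assert y.requires_grad                         # autograd flag is preserved through the copy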

Impact:
The MXFP8 split tensor bug fix directly resolves issue #2422, allowing checkpoint saving to work correctly with MXFP8 tensors in distributed training scenarios.

Confidence Score: 4/5

  • This PR is safe to merge with minor performance consideration
  • Score of 4 reflects that the core bug fix is correct and well-tested, but the contiguous() method change removes an early-return optimization that could cause unnecessary tensor copies when data is already contiguous
  • Pay attention to transformer_engine/pytorch/tensor/float8_tensor.py - the contiguous() method performance impact should be evaluated

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Fixed bug where scale_inv_out.shape was called on a list returned by torch.split; now correctly iterates over split tensors and pads each individually
transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Modified dequantize() to call contiguous() first, and changed contiguous() to always create a new tensor with both data and transpose contiguous, but this removes the early-return optimization
tests/pytorch/test_quantized_tensor.py | 5/5 | New comprehensive test file covering all quantized tensor types, including an MXFP8 chunk test that validates the bug fix
qa/L0_pytorch_unittest/test.sh | 5/5 | Updated test runner to use the renamed test file test_quantized_tensor.py instead of test_float8tensor.py

Sequence Diagram

sequenceDiagram
    participant User
    participant torch
    participant MXFP8Tensor
    participant Float8Tensor
    
    User->>torch: torch.chunk(mxfp8_tensor, 2, dim=0)
    torch->>MXFP8Tensor: __torch_dispatch__(aten.split.Tensor)
    
    Note over MXFP8Tensor: Split rowwise_data and columnwise_data
    MXFP8Tensor->>MXFP8Tensor: split _rowwise_data and _columnwise_data
    
    Note over MXFP8Tensor: Split scale_inv tensors (BUG WAS HERE)
    MXFP8Tensor->>MXFP8Tensor: split _rowwise_scale_inv
    MXFP8Tensor->>MXFP8Tensor: split _columnwise_scale_inv
    
    Note over MXFP8Tensor: OLD: scale_inv_out.shape ❌<br/>NEW: for each in scale_inv_out ✓
    
    loop For each split scale_inv tensor
        MXFP8Tensor->>MXFP8Tensor: pad split_scale_inv_out if needed
    end
    
    MXFP8Tensor->>MXFP8Tensor: Create MXFP8Tensor for each split
    MXFP8Tensor-->>torch: List[MXFP8Tensor]
    torch-->>User: Tuple of split tensors
    
    Note over User: Checkpoint save succeeds ✓

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment


current_shape = split_scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
    scale_inv_out[idx] = torch.nn.functional.pad(

Contributor

logic: Potential tuple mutability issue: if __torch_dispatch__ returns a tuple (immutable), this item assignment will fail with TypeError. Consider converting to list first:

Suggested change
- scale_inv_out[idx] = torch.nn.functional.pad(
+ scale_inv_out = list(scale_inv_out) if isinstance(scale_inv_out, tuple) else scale_inv_out
+ scale_inv_out[idx] = torch.nn.functional.pad(

Or convert immediately after dispatch on line 427.
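
To illustrate why the conversion matters, a minimal standalone sketch (not the Transformer Engine code itself):

import torch

x = torch.arange(8)
splits = torch.split(x, 4)       # torch.split returns a tuple of tensor views
try:
    splits[0] = torch.zeros(4)   # item assignment on a tuple raises TypeError
except TypeError:
    splits = list(splits)        # converting to a list makes the container mutable
    splits[0] = torch.zeros(4)   # now padding/replacing individual splits is possible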

@vthumbe1503 (Collaborator Author)

/te-ci pytorch

Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503 vthumbe1503 changed the title from "[Pytorch][Bug]MXFP8 Split tensor had bug in padding the scale inverse tensors." to "[Pytorch][Bug]MXFP8 Split tensor Bug fix" on Nov 26, 2025

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


@vthumbe1503 (Collaborator Author)

/te-ci pytorch

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment


Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <[email protected]>

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


        ((64, 128), 2, 1),  # Split along second dimension, goes down dequantization path for mxfp8
    ],
)
def test_fp8_split_functionality(quantization_type, shape, chunks, dim):

Collaborator

test_numerics.py is mostly focused on modules. This would be a better fit for test_float8tensor.py, which is more focused on granular functionality in Float8Tensor. Instead of creating a new file for MXFP8Tensor (very clunky, especially as we add more tensor classes in the future), I propose renaming that file to test_quantized_tensor.py so we have a common place for general tests like this.

Comment on lines 2981 to 3050
@pytest.mark.parametrize(
    "quantization_type",
    [
        "fp8",
        "mxfp8",
    ],
)
@pytest.mark.parametrize(
    "shape,chunks,dim",
    [
        ((64, 128), 2, 0),  # Split along first dimension, needs padding for mxfp8
        ((64, 128), 2, 1),  # Split along second dimension, goes down dequantization path for mxfp8
    ],
)
def test_fp8_split_functionality(quantization_type, shape, chunks, dim):
    """Test torch.chunk on FP8 and MXFP8 tensors and verify correctness via dequantization."""
    if quantization_type == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization_type == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    device = "cuda"
    dtype = torch.bfloat16

    # Create reference tensor
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    ref_tensor = torch.randn(shape, device=device, dtype=dtype)

    # Quantize the tensor
    if quantization_type == "fp8":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        quantized_tensor = quantizer(ref_tensor)
    elif quantization_type == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        quantized_tensor = quantizer(ref_tensor)

    # Apply torch.chunk on quantized tensor
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensor contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
    if dim == 0 or quantization_type == "fp8":
        # Dequantize the chunked results
        chunked_dequantized = [chunk.dequantize() for chunk in outs]
    else:
        # When splitting along second dimension, we go down the dequantization
        # route in case of mxfp8 for now.
        chunked_dequantized = outs

    # Reference: chunk the dequantized tensor directly
    ref_dequantized = quantized_tensor.dequantize()
    ref_chunked = torch.chunk(ref_dequantized, chunks, dim=dim)

    # Compare results
    assert len(chunked_dequantized) == len(
        ref_chunked
    ), f"Number of chunks mismatch: {len(chunked_dequantized)} vs {len(ref_chunked)}"

    for i, (chunk_deq, ref_chunk) in enumerate(zip(chunked_dequantized, ref_chunked)):
        assert (
            chunk_deq.shape == ref_chunk.shape
        ), f"Chunk {i} shape mismatch: {chunk_deq.shape} vs {ref_chunk.shape}"
        torch.testing.assert_close(
            chunk_deq,
            ref_chunk,
        )

Collaborator

We can easily make this into a general QuantizedTensor test with minimal knowledge of the internal implementation:

Suggested change
(the test_fp8_split_functionality test quoted above is removed and replaced with the following:)
@pytest.mark.parametrize(
    "quantization", ["fp8", "mxfp8", "fp8_blockwise", "nvfp4"],
)
@pytest.mark.parametrize("dim", [0, 1])
def test_chunk(
    *,
    quantization: str,
    shape: Iterable[int] = (128, 128),
    chunks: int = 2,
    dim: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = "cuda",
) -> None:

    # Skip invalid configs
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_blockwise" and not fp8_blockwise_available:
        pytest.skip(reason_for_no_fp8_blockwise)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    # Create quantizer
    if quantization == "fp8":
        quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            block_scaling_dim=1,
        )
    elif quantization == "nvfp4":
        quantizer = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )
    else:
        raise ValueError(f"Unknown quantizer ({quantization})")

    # Create reference and quantized tensor
    ref_tensor = torch.randn(shape, device=device, dtype=dtype)
    quantized_tensor = quantizer(ref_tensor)
    ref_tensor.copy_(quantized_tensor)

    # Chunk tensors
    ref_splits = torch.chunk(ref_tensor, chunks, dim=dim)
    quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim)

    # Check splits
    for ref_split, quantized_split in zip(ref_splits, quantized_splits):

        # Check split shapes
        assert ref_split.size() == quantized_split.size()

        # Check that splits are quantized when expected
        if quantization == "fp8":
            assert isinstance(quantized_split, Float8Tensor)
        if quantization == "mxfp8" and dim == 0:
            assert isinstance(quantized_split, MXFP8Tensor)

        # Check values
        torch.testing.assert_close(quantized_split, ref_split)

    # Apply torch.chunk on quantized tensor
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensor contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]

Collaborator

Can we move this contiguous into dequantize? Normal torch.Tensors work without it, so it's a problem in our implementation if we force users to do this extra unintuitive step.
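
One way this could look, as a hedged sketch only (the method body and helper name below are assumptions, not the actual Transformer Engine API):

# Hypothetical sketch: let dequantize() normalize the layout itself so callers
# never need an explicit contiguous() step after torch.chunk.
def dequantize(self, *, dtype=None):
    tensor = self.contiguous()  # handle non-contiguous data/transpose up front
    return tensor._dequantize_impl(dtype=dtype)  # hypothetical internal helper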

    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensor contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
    if dim == 0 or quantization_type == "fp8":

Collaborator

Stylistic nit: This doesn't generalize to new recipes. If we generalize, we get something like:

Suggested change
- if dim == 0 or quantization_type == "fp8":
+ if quantization_type == "fp8" or (quantization_type == "mxfp8" and dim == 0):

The extra robustness is not that important, but notice how much more readable this is. Basically we are enumerating the "special cases" where we have to dequantize. It's worth putting thought into generalization, even if we never plan on doing it, because it forces you to understand the code at a logical level.

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


@vthumbe1503 (Collaborator Author)

/te-ci pytorch

@vthumbe1503 (Collaborator Author)

/te-ci pytorch

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments


@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/tensor/float8_tensor.py, line 557-565

    style: always creates a new tensor even when already contiguous - consider an early-return optimization:
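
    For example, an early-return guard along these lines (a sketch using the _data/_transpose attributes mentioned in the summary; not the actual implementation):

    def contiguous(self, memory_format=torch.contiguous_format):
        # Skip the copy when all underlying buffers are already laid out contiguously.
        data_ok = self._data is None or self._data.is_contiguous(memory_format=memory_format)
        transpose_ok = self._transpose is None or self._transpose.is_contiguous()
        if data_ok and transpose_ok:
            return self
        ...  # otherwise fall through to the existing copy path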

4 files reviewed, 1 comment


@vthumbe1503 (Collaborator Author)

/te-ci pytorch

Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments


Signed-off-by: Varun Thumbe <[email protected]>

@vthumbe1503 (Collaborator Author)

/te-ci pytorch

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 1 comment



Development

Successfully merging this pull request may close these issues.

AttributeError on mxfp8_tensor while checkpointing a model
