[Pytorch][Bug] MXFP8 Split tensor Bug fix #2427
base: main
Conversation
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary: This PR fixes a critical bug in MXFP8 tensor splitting that caused checkpoint saving to fail.

Key Changes:

Impact:

Confidence Score: 4/5

Important Files Changed: File Analysis
Sequence Diagram

sequenceDiagram
participant User
participant torch
participant MXFP8Tensor
participant Float8Tensor
User->>torch: torch.chunk(mxfp8_tensor, 2, dim=0)
torch->>MXFP8Tensor: __torch_dispatch__(aten.split.Tensor)
Note over MXFP8Tensor: Split rowwise_data and columnwise_data
MXFP8Tensor->>MXFP8Tensor: split _rowwise_data and _columnwise_data
Note over MXFP8Tensor: Split scale_inv tensors (BUG WAS HERE)
MXFP8Tensor->>MXFP8Tensor: split _rowwise_scale_inv
MXFP8Tensor->>MXFP8Tensor: split _columnwise_scale_inv
Note over MXFP8Tensor: OLD: scale_inv_out.shape ❌<br/>NEW: for each in scale_inv_out ✓
loop For each split scale_inv tensor
MXFP8Tensor->>MXFP8Tensor: pad split_scale_inv_out if needed
end
MXFP8Tensor->>MXFP8Tensor: Create MXFP8Tensor for each split
MXFP8Tensor-->>torch: List[MXFP8Tensor]
torch-->>User: Tuple of split tensors
Note over User: Checkpoint save succeeds ✓
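
To make the OLD vs NEW note in the diagram concrete, here is a minimal sketch of the corrected padding step. It is illustrative only: the helper name pad_split_scale_invs and the exact pad widths are assumptions, while scale_inv_out, split_scale_inv_out, and pad_multiple follow the snippet quoted in the review comment below.

import torch

def pad_split_scale_invs(scale_inv_out, pad_multiple):
    """Pad each split scale_inv tensor along dim 0 to a multiple of pad_multiple."""
    # __torch_dispatch__ may hand back an immutable tuple, so convert before item assignment.
    scale_inv_out = list(scale_inv_out)
    for idx, split_scale_inv_out in enumerate(scale_inv_out):
        # The old code effectively read .shape on the whole collection of splits;
        # the fix inspects each split tensor individually.
        current_shape = split_scale_inv_out.shape
        pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
        if pad_dim0 > 0:
            # Zero-pad rows at the end of dim 0 (last two entries of the pad tuple).
            scale_inv_out[idx] = torch.nn.functional.pad(
                split_scale_inv_out, (0, 0, 0, pad_dim0)
            )
    return scale_inv_out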
3 files reviewed, 1 comment
current_shape = split_scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
    scale_inv_out[idx] = torch.nn.functional.pad(
logic: Potential tuple mutability issue: if __torch_dispatch__ returns a tuple (immutable), this item assignment will fail with TypeError. Consider converting to list first:
scale_inv_out = list(scale_inv_out) if isinstance(scale_inv_out, tuple) else scale_inv_out
scale_inv_out[idx] = torch.nn.functional.pad(
Or convert immediately after dispatch on line 427.
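
A tiny standalone illustration of the failure mode being flagged, using plain torch.Tensor.split rather than the TE internals (shapes and pad widths are arbitrary):

import torch

splits = torch.zeros(4, 4).split(2)  # torch.Tensor.split returns a tuple
try:
    splits[0] = torch.nn.functional.pad(splits[0], (0, 0, 0, 1))
except TypeError:
    # Tuples do not support item assignment; convert to a list first.
    splits = list(splits)
    splits[0] = torch.nn.functional.pad(splits[0], (0, 0, 0, 1))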
/te-ci pytorch
Signed-off-by: Varun Thumbe <[email protected]>
3 files reviewed, no comments
/te-ci pytorch
3 files reviewed, 1 comment
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <[email protected]>
3 files reviewed, no comments
tests/pytorch/test_numerics.py (Outdated)
        ((64, 128), 2, 1),  # Split along second dimension, goes down dequantization path for mxfp8
    ],
)
def test_fp8_split_functionality(quantization_type, shape, chunks, dim):
test_numerics.py is mostly focused on modules. This would be a better fit for test_float8tensor.py, which is more focused on granular functionality in Float8Tensor. Instead of creating a new file for MXFP8Tensor (very clunky, especially as we add more tensor classes in the future), I propose renaming that file to test_quantized_tensor.py so we have a common place for general tests like this.
tests/pytorch/test_numerics.py (Outdated)
@pytest.mark.parametrize(
    "quantization_type",
    [
        "fp8",
        "mxfp8",
    ],
)
@pytest.mark.parametrize(
    "shape,chunks,dim",
    [
        ((64, 128), 2, 0),  # Split along first dimension, needs padding for mxfp8
        ((64, 128), 2, 1),  # Split along second dimension, goes down dequantization path for mxfp8
    ],
)
def test_fp8_split_functionality(quantization_type, shape, chunks, dim):
    """Test torch.chunk on FP8 and MXFP8 tensors and verify correctness via dequantization."""
    if quantization_type == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization_type == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    device = "cuda"
    dtype = torch.bfloat16

    # Create reference tensor
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    ref_tensor = torch.randn(shape, device=device, dtype=dtype)

    # Quantize the tensor
    if quantization_type == "fp8":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        quantized_tensor = quantizer(ref_tensor)
    elif quantization_type == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        quantized_tensor = quantizer(ref_tensor)

    # Apply torch.chunk on quantized tensor
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensors contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
    if dim == 0 or quantization_type == "fp8":
        # Dequantize the chunked results
        chunked_dequantized = [chunk.dequantize() for chunk in outs]
    else:
        # When splitting along second dimension, we go down dequantization
        # route in case of mxfp8 for now.
        chunked_dequantized = outs

    # Reference: chunk the dequantized tensor directly
    ref_dequantized = quantized_tensor.dequantize()
    ref_chunked = torch.chunk(ref_dequantized, chunks, dim=dim)

    # Compare results
    assert len(chunked_dequantized) == len(
        ref_chunked
    ), f"Number of chunks mismatch: {len(chunked_dequantized)} vs {len(ref_chunked)}"

    for i, (chunk_deq, ref_chunk) in enumerate(zip(chunked_dequantized, ref_chunked)):
        assert (
            chunk_deq.shape == ref_chunk.shape
        ), f"Chunk {i} shape mismatch: {chunk_deq.shape} vs {ref_chunk.shape}"
        torch.testing.assert_close(
            chunk_deq,
            ref_chunk,
        )
We can easily make this into a general QuantizedTensor test with minimal knowledge of the internal implementation:
@pytest.mark.parametrize(
    "quantization", ["fp8", "mxfp8", "fp8_blockwise", "nvfp4"],
)
@pytest.mark.parametrize("dim", [0, 1])
def test_chunk(
    *,
    quantization: str,
    shape: Iterable[int] = (128, 128),
    chunks: int = 2,
    dim: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = "cuda",
) -> None:

    # Skip invalid configs
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_blockwise" and not fp8_blockwise_available:
        pytest.skip(reason_for_no_fp8_blockwise)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    # Create quantizer
    if quantization == "fp8":
        quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            block_scaling_dim=1,
        )
    elif quantization == "nvfp4":
        quantizer = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )
    else:
        raise ValueError(f"Unknown quantization scheme ({quantization})")

    # Create reference and quantized tensor
    ref_tensor = torch.randn(shape, device=device, dtype=dtype)
    quantized_tensor = quantizer(ref_tensor)
    ref_tensor.copy_(quantized_tensor)

    # Chunk tensors
    ref_splits = torch.chunk(ref_tensor, chunks, dim=dim)
    quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim)

    # Check splits
    for ref_split, quantized_split in zip(ref_splits, quantized_splits):

        # Check split shapes
        assert ref_split.size() == quantized_split.size()

        # Check that splits are quantized when expected
        if quantization == "fp8":
            assert isinstance(quantized_split, Float8Tensor)
        if quantization == "mxfp8" and dim == 0:
            assert isinstance(quantized_split, MXFP8Tensor)

        # Check values
        torch.testing.assert_close(quantized_split, ref_split)
tests/pytorch/test_numerics.py (Outdated)
    # Apply torch.chunk on quantized tensor
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensors contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
Can we move this contiguous call into dequantize? Normal torch.Tensors work without it, so it's a problem in our implementation if we force users to do this extra, unintuitive step.
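
A rough sketch of what folding the contiguous call into dequantize could look like. This is a toy stand-in class, not the actual Float8Tensor/MXFP8Tensor implementation, and the attribute names are assumptions:

import torch

class _ToyQuantizedTensor:
    """Stand-in used only to illustrate the reviewer's suggestion."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale

    def dequantize(self) -> torch.Tensor:
        data = self._data
        if not data.is_contiguous():
            # Normalize the layout internally so callers never need an
            # explicit .contiguous() step after torch.chunk.
            data = data.contiguous()
        return data.to(torch.float32) * self._scale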
tests/pytorch/test_numerics.py (Outdated)
    quantized_tensor_dispatch_out = torch.chunk(quantized_tensor, chunks, dim=dim)
    # Need to make tensors contiguous for dim=1 splitting.
    outs = [out.contiguous() for out in quantized_tensor_dispatch_out]
    if dim == 0 or quantization_type == "fp8":
Stylistic nit: This doesn't generalize to new recipes. If we generalize, we get something like:
    if quantization_type == "fp8" or (quantization_type == "mxfp8" and dim == 0):
The extra robustness is not that important, but notice how much more readable this is. Basically we are enumerating the "special cases" where we have to dequantize. It's worth putting thought into generalization, even if we never plan on doing it, because it forces you to understand the code at a logical level.
3 files reviewed, no comments
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
/te-ci pytorch

/te-ci pytorch
4 files reviewed, no comments
Additional Comments (1)
- transformer_engine/pytorch/tensor/float8_tensor.py, lines 557-565: style: always creates a new tensor even when already contiguous; consider an early-return optimization:
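
A minimal sketch of the early-return pattern being suggested, using a toy stand-in class rather than the actual Float8Tensor code at those lines (attribute names are assumptions):

import torch

class _ToyFp8Tensor:
    """Stand-in used only to illustrate the early-return optimization."""

    def __init__(self, data: torch.Tensor):
        self._data = data

    def contiguous(self) -> "_ToyFp8Tensor":
        # Return self when the underlying storage is already contiguous,
        # instead of unconditionally building a new tensor.
        if self._data.is_contiguous():
            return self
        return _ToyFp8Tensor(self._data.contiguous())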
4 files reviewed, 1 comment
/te-ci pytorch
Signed-off-by: Varun Thumbe <[email protected]>
4 files reviewed, no comments
Signed-off-by: Varun Thumbe <[email protected]>
/te-ci pytorch
for more information, see https://pre-commit.ci
4 files reviewed, 1 comment
Description
Fixes #2422
Type of change
Changes
Checklist: