
[PyTorch] Support TP Overlap in Per-Tensor Current Scaling Recipe #1554

Merged: 20 commits into NVIDIA:main, Mar 15, 2025

Conversation

@BestJuly (Contributor) commented on Mar 10, 2025

Description

Enable TP/Comm Overlap for Per-Tensor Current Scaling Recipe.

Type of change

  • Documentation change (changes only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Unit-test-related code changes
  • Module-related code changes (i.e., linear, layernorm_linear, layernorm_mlp)
  • Quantization-related code changes (to help use the correct scale_inv value)

Unit Tests

Python Unit Tests

# Test Current Scaling only
pytest -s tests/pytorch/distributed/test_comm_gemm_overlap.py -k "FP8 and CURRENT" -v

# Test all overlapping cases
pytest -s tests/pytorch/distributed/test_comm_gemm_overlap.py -v

# Particular pattern test, take `LayernormLinear` module as an example
# Set UB_SKIPMC=1 if necessary
torchrun --nproc_per_node 2 tests/pytorch/distributed/run_layer_with_overlap.py --seq-length=1024 --batch-size=1 --num-heads=8 --head-dim=64 --layer-type=layernormlinear --fp8 --quantization=fp8_current_scaling

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
- at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts);
+ // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set.
+ at::Tensor scale_inv = at::reciprocal(scale);
Collaborator:

In principle we don't know if the value of scale is valid at this point. It is a workspace buffer for an intermediate value during quantization, and we only expose it outside quantization to make debugging more convenient: #1471 (comment)

The actual bug is that the UB comm doesn't handle scales properly in copy_into_buffer and get_buffer. This is an indirect fix that assumes that UB is constructing the gathered tensor after it has called the quantizer. This might not always be true: pipeline parallelism involves multiple forward passes and then multiple backward passes, so input_quantizer might have been called multiple times in between a forward and backward pass.

I'm working on a more general bugfix, so I'm fine accepting this as a temporary fix.
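
To make the pipeline-parallel hazard concrete, here is a small self-contained toy example (all names are hypothetical; this is not TE code): a shared quantizer's scale workspace is overwritten by later forward passes before an earlier backward pass reads it, so reconstructing scale_inv from it at backward time uses a stale value.

    class ToyQuantizer:
        """Stand-in for a quantizer whose `scale` is a reusable workspace value."""
        def __init__(self):
            self.scale = 1.0  # initialized to 1 to avoid division by zero
        def quantize(self, amax):
            self.scale = 448.0 / amax  # per-tensor current scaling from this call's amax
            return self.scale

    q = ToyQuantizer()
    forward_scales = []
    for amax in (2.0, 16.0):  # two pipeline microbatches run their forward passes first
        forward_scales.append(q.quantize(amax))
    # The backward pass of microbatch 0 now reads q.scale, but it holds microbatch 1's value,
    # so computing scale_inv as 1 / q.scale here would use a stale scale.
    assert q.scale == forward_scales[1] and q.scale != forward_scales[0]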

Contributor Author (BestJuly):

UB is initialized without any quantization-related info; I think that is why copy_into_buffer does not handle the scale. get_buffer then uses the scale information from the quantizer, which duplicates the calculation of scale_inv. Agreed that the change here is indeed a temporary fix.

# reduce duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_tmp = ln_out
Contributor:

maybe we should give it a better name

Contributor Author (BestJuly):

Use ln_out_local instead, which should be more accurate.

)
if fp8 and ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
Contributor:

This optimization can also be applied to grad_output_preprocess in base.py, where we do grad_output = quantizer(grad_output).

If we want to all-gather grad_output (dY) under sequence + row parallelism, we expect UB to perform that all-gather, so we should also call set_usage(rowwise=True, columnwise=False) before it. Otherwise we waste one at::empty() malloc and one cast_transpose that could have been just a cast; see the sketch after the quoted code below.

The actual transpose of grad_output happens in linear/layernormlinear after we call the ub get_buffer.

            if isinstance(grad_output, QuantizedTensor):
                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
                # already exists.
                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
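
A hedged sketch of the suggested change (the names ctx.ub_overlap_ag and ctx.grad_output_quantizer are assumptions for illustration, not necessarily the exact attributes in base.py): restrict the quantizer to row-wise usage before quantizing dY, so the quantize becomes a plain cast instead of a cast_transpose plus the extra at::empty().

    # Sketch only; attribute names are assumed for illustration.
    if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, Float8CurrentScalingQuantizer):
        # UB will all-gather the row-wise dY, so skip producing the transposed copy here;
        # the transpose can be reconstructed after get_buffer, as discussed above.
        ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
    grad_output = ctx.grad_output_quantizer(grad_output)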

Contributor:

There is one thing I discussed offline with @timmoon10 about this issue, for the day we might add TP GEMM overlap to MXFP8.

For MXFP8 the scaling granularity is 1x32, so the "transpose" isn't just a data transpose but also a change of scaling granularity to 32x1. We might therefore need two all-gather ops in UB to overlap with the GEMM.

In that case we would still need set_usage(rowwise=True, columnwise=True) to make sure we have both copies of the data to communicate.

So instead of doing if ub_bulk_dgrad:, maybe we should do

if ub_bulk_dgrad and is_per_tensor_recipe():

The is_per_tensor_scaling covers both delayed scaling and per-tensor current scaling. Both recipes should benefit from this optimization.
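
A hedged sketch of that guard, since is_per_tensor_recipe() is not an existing helper here; this version spells it out with the quantizer classes instead (Float8Quantizer is assumed to be the delayed-scaling quantizer class):

    if ub_bulk_dgrad and isinstance(input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        # Per-tensor recipes (delayed and current scaling) can drop the columnwise copy;
        # MXFP8 would still need both, since its 1x32 block scales change under transpose.
        input_quantizer.set_usage(rowwise=True, columnwise=False)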

if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer):
if ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
Contributor:

why should we call set_usage(rowwise=True, columnwise=False) twice?

Contributor Author (BestJuly):

set_usage will indeed be called several times. Currently I disable the duplicated transpose processing with another set_usage call; otherwise we would need to add such expressions under many conditions to cover each possible case. I am not sure about the overhead of calling set_usage multiple times.

@timmoon10 (Collaborator) left a comment:

Overall LGTM

@timmoon10 (Collaborator):

/te-ci pytorch L1

ub_obj_fprop = None
ln_out = None
- if ub_overlap_ag_fprop:
+ if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
Contributor:

We should leave a comment here to explain why we skip the current scaling quantizer.

I guess the reason is: for current scaling, we want the layernorm to output bf16 rather than fp8, right?

Contributor Author (BestJuly):

Have added comments to address this.
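
A hedged sketch of the kind of explanatory comment described here (the wording is mine, not necessarily what was added in the PR), attached to the condition from the diff above:

    # With delayed scaling the FP8 scale is already known, so the layernorm can write its
    # output directly into the FP8 UB buffer. With current scaling the scale depends on the
    # amax of the full output, so the layernorm emits high precision and the quantization
    # into the UB buffer happens afterwards.
    if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
        ub_obj_fprop = get_ub(ub_name + "_fprop")
        ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)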

)
if ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
Contributor:

Should we also add a per_tensor_scaling check to the if condition here?

Contributor Author (BestJuly):

Done.

@@ -424,6 +429,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
inputmat = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
@zhongbozhu (Contributor) commented on Mar 11, 2025:

why do we need this line?

We use inputmat_total for the wgrad calculation, and the userbuffers handle the all-gather of inputmat to get inputmat_total, so we already have a line like the one below that fetches the gathered inputmat_total via ub get_buffer.

            wgrad = None
            if ctx.requires_wgrad:
                if ctx.ub_bulk_dgrad:
                    inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)

If this line is not necessary then we should remove it because get_buffer is not a free op.

Contributor Author (BestJuly):

Good catch. Removing this line does not affect any results. Have deleted this line.

@zhongbozhu (Contributor) commented on Mar 11, 2025:

Nit in the description section:

# Particular pattern test, take `LayernormLinear` module as an example
# Set UB_SKIPMC=1 if necessary
torchrun --nproc_per_node 2 tests/pytorch/distributed/run_layer_with_overlap.py --seq-length=1024 --batch-size=1 --num-heads=8 --head-dim=64 --layer-type=layernormlinear --fp8 --fp8-recipe=tensorwise

should be --quantization=fp8_current_scaling

# Set UB_SKIPMC=1 if necessary
torchrun --nproc_per_node 2 tests/pytorch/distributed/run_layer_with_overlap.py --seq-length=1024 --batch-size=1 --num-heads=8 --head-dim=64 --layer-type=layernormlinear --fp8 --quantization=fp8_current_scaling

@zhongbozhu (Contributor):

Done some tests myself, LGTM

ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_local = ln_out
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
input_quantizer.quantize(ln_out_local, out=ln_out)
@zhongbozhu (Contributor) commented on Mar 12, 2025:

Hmmm, I think I still have questions about this line, and we need to figure this out before we merge.

This line (plus another one in layernorm_mlp) are the only places where you call get_buffer first and then use quantizer.quantize(inp, out) to do the quantization (i.e., an in-place quantize without needing to malloc an output tensor). This pattern implies that the output tensor's dataptr is managed by the userbuffers system.

It's nice that this might save us some time copying data around, but I am worried that in actual E2E training, is it okay to save such a tensor for backward if its dataptr belongs to the ubuf? Won't the data be corrupted when we move on to the next layer?

Unfortunately our layer-level unit test cannot catch this, so I am not sure how to prove my point. @timmoon10 what do you think?

Collaborator:

Good catch, I think this is a bug from the 2.0 release. Before that, the GEMM function had an option to copy values from the UB buffer into a plain PyTorch tensor:

extra_output_tensor=ln_out if ub_overlap_ag else None,

We didn't handle this when updating to 2.0, and I agree that it seems likely that ln_out is corrupted in between the forward and backward pass.

Contributor Author (BestJuly):

The key point is that we use get_buffer, so ln_out is a UB buffer that gets saved and restored during training. All transformer layers share the same UB, so its values will change.
I think this can be replaced with copy_into_buffer, so that ln_out stays a normal tensor. That ensures UB has the correct values and that the tensor saved in prepare_for_saving is not a UB buffer.
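
A hedged sketch of the copy_into_buffer variant described above (not the exact PR code; the argument order follows the copy_into_buffer call quoted earlier in this conversation): quantize into a tensor owned by PyTorch, then copy the local chunk into the UB buffer, so the tensor saved for backward is not backed by the shared UB workspace.

    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = input_quantizer(ln_out_local)  # regular FP8 tensor, safe to save for backward
    ub_obj_fprop.copy_into_buffer(ln_out, input_quantizer, local_chunk=True)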

Contributor:

yes

ub_obj_lnout = get_ub("fc1_fprop")
ln_out_local = ln_out
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
fc1_input_quantizer.quantize(ln_out_local, out=ln_out)
Contributor:

And here is the other one.

- if ub_overlap_ag_fprop:
+ # For DelayedScaling, the output of normalization will be in fp8.
+ # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
+ if ub_overlap_ag_fprop and not fp8:
      ub_obj_fprop = get_ub(ub_name + "_fprop")
      ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
Contributor:

Why are we distinguishing between the bf16 case and the fp8 case?

Contributor Author (BestJuly):

Because in the bf16 case UB can still be used, so I did not change the logic here.
The main fix is that the return value is changed to ln_out.clone() if ub_overlap_ag_fprop else ln_out.
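
A minimal sketch of that return-value change (names follow the surrounding code): return a clone whenever the layernorm output lives in the shared UB buffer, so the value handed back to the caller is not aliased to the workspace.

    ln_out_return = ln_out.clone() if ub_overlap_ag_fprop else ln_out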

    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
elif ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
    ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
Contributor:

This elif used to belong to FP8 delayed scaling, right? For delayed scaling we still want the layernorm to output directly into an fp8 buffer; the only difference is whether that fp8 buffer is allocated by PyTorch or by the ubuf.

If we just allocate with inputmat.dtype here, are we wasting some memory? We were only supposed to allocate enough for an fp8 tensor.

@timmoon10 (Collaborator):

/te-ci pytorch L1

@zhongbozhu (Contributor):

Cannot find any more issues, guess we are good to go? @timmoon10

@timmoon10 (Collaborator):

/te-ci pytorch L1

@zhongbozhu force-pushed the lit/enable_cs_tp_overlap branch from 673ee32 to f686a49 on March 14, 2025 at 22:54
@timmoon10 (Collaborator):

/te-ci pytorch L1

@timmoon10 (Collaborator):

This fixes some TransformerLayer UB test failures (by capping to TP=4), but it introduces some other TransformerLayer UB test failures (numerical error slightly beyond bounds with FP8 current scaling).

This PR has a critical bugfix (see #1554 (comment)), so I think it's better to merge quickly and fiddle with the test tolerances later.

@timmoon10 merged commit a7eeb28 into NVIDIA:main on Mar 15, 2025 (11 of 12 checks passed).