[PyTorch] Support TP Overlap in Per-Tensor Current Scaling Recipe #1554
Conversation
opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts);
// In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set.
at::Tensor scale_inv = at::reciprocal(scale);
In principle we don't know if the value of `scale` is valid at this point. It is a workspace buffer for an intermediate value during quantization, and we only expose it outside quantization to make debugging more convenient: #1471 (comment)
The actual bug is that the UB comm doesn't handle scales properly in `copy_into_buffer` and `get_buffer`. This is an indirect fix that assumes UB constructs the gathered tensor after it has called the quantizer. That might not always be true: pipeline parallelism involves multiple forward passes and then multiple backward passes, so `input_quantizer` might have been called multiple times in between a forward and backward pass.
I'm working on a more general bugfix, so I'm fine accepting this as a temporary fix.
UB is initialized without quantization-related info; I think that is why `copy_into_buffer` does not handle the scale part. `get_buffer` then uses the scale information from the quantizer, which duplicates the calculation of `scale_inv`. Agreed that the change here is indeed a temporary fix.
# reduce duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_tmp = ln_out
Maybe we should give it a better name.
Use `ln_out_local` instead, which should be more accurate.
)
if fp8 and ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
This optimization can also apply to the `grad_output_preprocess` function in base.py, where we do `grad_output = quantizer(grad_output)`. If we want to all-gather `grad_output` under sequence + row parallelism, we expect UB to do the all-gather of `grad_output` (i.e. dY), so we should also call `set_usage(rowwise=True, columnwise=False)` before it. Otherwise we will also be wasting one `at::empty()` for the malloc and one cast-transpose that could have been just a cast.
The actual transpose of `grad_output` happens in Linear/LayerNormLinear after we call the UB `get_buffer`.
if isinstance(grad_output, QuantizedTensor):
    # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
    # already exists.
    grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
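A minimal sketch of where the suggested `set_usage` call could go in `grad_output_preprocess`. The guard names (`ctx.ub_overlap_ag`, `per_tensor_recipe`) are illustrative assumptions, not a verified diff; only `quantizer(grad_output)` comes from the existing code.

```python
# Sketch: before quantizing grad_output for a UB-overlapped all-gather, request
# rowwise-only data so the quantizer does a plain cast instead of a
# cast-transpose (saving one at::empty() and one transpose kernel). The
# columnwise copy is produced later, after ub.get_buffer(), via update_usage().
if ctx.ub_overlap_ag and per_tensor_recipe:  # illustrative condition
    quantizer.set_usage(rowwise=True, columnwise=False)
grad_output = quantizer(grad_output)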
There is one thing I discussed offline with @timmoon10 about this issue, for when we might one day add TP GEMM overlap to MXFP8.
For MXFP8, the scaling granularity is 1x32, so the "transpose" isn't just a data transpose but also a change of scaling granularity to 32x1. We might therefore need two all-gather ops in UB to overlap with the GEMM, and in that case we still need `set_usage(rowwise=True, columnwise=True)` to make sure we have both copies of data to communicate.
So instead of doing `if ub_bulk_dgrad:`, maybe we should do `if ub_bulk_dgrad and is_per_tensor_recipe():`. `is_per_tensor_scaling` covers both delayed scaling and per-tensor current scaling; both recipes should benefit from this optimization.
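A sketch of the suggested guard. The `is_per_tensor_recipe` helper is the hypothetical one proposed above, and the `Recipe` method names are assumptions about the TE recipe API rather than a verified call site.

```python
def is_per_tensor_recipe(recipe) -> bool:
    # Hypothetical helper: true for delayed scaling and per-tensor current
    # scaling, false for block-scaled recipes such as MXFP8, whose "transpose"
    # also changes the scaling granularity from 1x32 to 32x1.
    return recipe.delayed() or recipe.float8_current_scaling()

if fp8 and ub_bulk_dgrad and is_per_tensor_recipe(fp8_recipe):
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
```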
if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer):
if ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
Why should we call `set_usage(rowwise=True, columnwise=False)` twice?
`set_usage` will indeed be called several times. Currently, I disable the duplicated transpose processing with another `set_usage` call; otherwise we would need to add such expressions under a lot of conditions for each possible case. I'm not sure about the overhead of calling `set_usage` multiple times.
Overall LGTM
/te-ci pytorch L1
ub_obj_fprop = None
ln_out = None
if ub_overlap_ag_fprop:
if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
We should leave a comment here to explain why we skip the current scaling quantizer. The reason, I guess, is that for current scaling we don't want layernorm to output fp8 but rather bf16, right?
Have added comments to address this.
)
if ub_bulk_dgrad:
    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
    input_quantizer.set_usage(rowwise=True, columnwise=False)
Should we also add a per-tensor-scaling check to this if condition?
Done.
@@ -424,6 +429,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
inputmat = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
Why do we need this line? We are using `inputmat_total` for the wgrad calculation, and the user buffer handles the all-gather of `inputmat` to get `inputmat_total`, so we already have a line like the one below that fetches the gathered `inputmat_total` using the UB `get_buffer`:
wgrad = None
if ctx.requires_wgrad:
    if ctx.ub_bulk_dgrad:
        inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
If this line is not necessary, then we should remove it, because `get_buffer` is not a free op.
Good catch. Removing this line does not affect any results, so I have deleted it.
Nit in the description section: ... should be ...
Done some tests myself, LGTM.
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_local = ln_out
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
input_quantizer.quantize(ln_out_local, out=ln_out)
Hmmm, I think I still have questions about this line, and we need to figure this out before we merge.
This line (plus another one in layernorm_mlp) is the only place where you call `get_buffer` first and then use `quantizer.quantize(inp, out)` to do the quantization (i.e. in-place quantization without needing to malloc an output tensor). This pattern implies that the output tensor's data pointer is managed by the user-buffer system.
It's nice that this might save us some time copying data around, but I am worried about actual E2E training: is it okay to save such a tensor for backward if its data pointer belongs to the ubuf? Shouldn't the data be corrupted when we move to the next layer?
Unfortunately, our layer-level unit tests cannot catch this, so I am not sure how I can prove my point. @timmoon10 what do you think?
Good catch, I think this is a bug from the 2.0 release. Before that, the GEMM function had an option to copy values from the UB buffer into a plain PyTorch tensor:
extra_output_tensor=ln_out if ub_overlap_ag else None,
We didn't handle this when updating to 2.0, and I agree that it seems likely that `ln_out` is corrupted between the forward and backward pass.
The key point is that we use `get_buffer`, after which `ln_out` is a UB buffer that gets saved and restored during training. All transformer layers share the same UB, so its values would change.
I think it can be replaced by using `copy_into_buffer`, where `ln_out` is a normal tensor. This makes sure that UB has the correct values and that the tensor saved in `prepare_for_saving` is not a UB buffer.
yes
ub_obj_lnout = get_ub("fc1_fprop")
ln_out_local = ln_out
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
fc1_input_quantizer.quantize(ln_out_local, out=ln_out)
And here is the other one.
if ub_overlap_ag_fprop:
# For DelayedScaling, the output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if ub_overlap_ag_fprop and not fp8:
    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
Why are we making a distinction between the bf16 case and the fp8 case?
Because for the bf16 case, UB can still be used, so I do not change the logic there. The main fix is that the return value is changed to `ln_out.clone() if ub_overlap_ag_fprop else ln_out`.
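A tiny sketch of why that clone matters, under the assumption that `ln_out` was obtained from `get_buffer` in this branch:

```python
# ln_out aliases the shared userbuffers workspace when UB overlap is on; clone
# it so the returned (and possibly saved-for-backward) value survives reuse of
# the same UB buffer by later layers.
return ln_out.clone() if ub_overlap_ag_fprop else ln_out
```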
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
elif ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
    ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
This elif used to belong to FP8 delayed scaling, right? For delayed scaling, we still want layernorm to output directly into an fp8 buffer; the only difference is whether that fp8 buffer is allocated by PyTorch or by the ubuf.
If we allocate with `inputmat.dtype` here, are we wasting some memory? We were only supposed to allocate enough for an fp8 tensor.
/te-ci pytorch L1
Cannot find any more issues; I guess we are good to go? @timmoon10
/te-ci pytorch L1
/te-ci pytorch L1
This PR has a critical bugfix (see #1554 (comment)), so I think it's better to merge quickly and fiddle with the test tolerances later.
Description
Enable TP/Comm Overlap for Per-Tensor Current Scaling Recipe.
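A hedged usage sketch of what this enables. It assumes a multi-GPU `torchrun` launch with NVLink; the `initialize_ub` arguments, module kwargs, and shapes follow the public TE PyTorch API but are illustrative and may differ in detail from this PR's tests.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Launch with e.g. `torchrun --nproc-per-node=<tp_size> script.py`.
torch.distributed.init_process_group(backend="nccl")
tp_size = torch.distributed.get_world_size()
torch.cuda.set_device(torch.distributed.get_rank())

seq_len, batch_size, hidden_size = 2048, 2, 4096

# Register the userbuffers workspaces used for TP comm/GEMM overlap.
te.initialize_ub([seq_len * batch_size, hidden_size], tp_size, use_fp8=True)

layer = te.LayerNormLinear(
    hidden_size, 3 * hidden_size,
    tp_group=torch.distributed.group.WORLD, tp_size=tp_size,
    sequence_parallel=True, parallel_mode="column",
    ub_overlap_ag=True, ub_name="qkv",
).cuda()

# Each rank holds its sequence-parallel shard of the input.
inp = torch.randn(seq_len // tp_size, batch_size, hidden_size, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    out = layer(inp)
```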
Type of change
Changes
- Handle the `scale_inv` value correctly for TP overlap
Unit Tests
- Python Unit Tests
Checklist: