FSDP grad fusion support #2191
base: main
Conversation
Signed-off-by: Selvaraj Anandaraj <[email protected]>
for more information, see https://pre-commit.ci
I don't think this makes sense. If you configure a TE module with fuse_wgrad_accumulation=True (e.g. here), the correct behavior is to fuse wgrad accumulation. If Mcore FSDP doesn't support it, then it should be Mcore's responsibility to not set that arg.
The root problem is that Mcore DDP and FSDP have different behaviors and require different contracts with TE.
I don't like this PR's approach of switching between these two cases based on whether Mcore is using DDP or FSDP, since that's not actually the important thing. It also needlessly blocks some possible optimizations (e.g. DDP might want to overwrite main_grad rather than accumulate into it). There are a few possible redesigns, e.g.:
grad_weight: torch.Tensor
accumulate: bool = False
if output_wgrad_to_main_grad:
    if getattr(weight, "get_main_grad", None) is not None:
        grad_weight = weight.get_main_grad()
    else:
        grad_weight = weight.main_grad
    # Accumulate into main_grad unless the buffer owner asks for an overwrite.
    accumulate = not getattr(weight, "_overwrite_main_grad", False)
else:
    grad_weight = torch.empty(...)
gemm(..., out=grad_weight, accumulate=accumulate)

Ensuring backward compatibility will be tricky.
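As an illustration of the proposed contract, here is a minimal, self-contained sketch. The helper `resolve_wgrad_output` is hypothetical, and the `output_wgrad_to_main_grad`, `get_main_grad`, and `_overwrite_main_grad` names are taken from the pseudocode above rather than from an existing API:

```python
import torch

def resolve_wgrad_output(weight, output_wgrad_to_main_grad: bool):
    """Pick the wgrad output buffer and accumulation flag per the sketch above."""
    if output_wgrad_to_main_grad:
        # Buffer owners that materialize main_grad lazily can expose a getter.
        if getattr(weight, "get_main_grad", None) is not None:
            grad_weight = weight.get_main_grad()
        else:
            grad_weight = weight.main_grad
        # Default contract: accumulate into main_grad; owners may request overwrite.
        accumulate = not getattr(weight, "_overwrite_main_grad", False)
    else:
        grad_weight = torch.empty_like(weight)
        accumulate = False
    return grad_weight, accumulate

# DDP-style owner: persistent main_grad buffer, accumulated across microbatches.
w_ddp = torch.nn.Parameter(torch.randn(4, 4))
w_ddp.main_grad = torch.zeros_like(w_ddp)
buf, acc = resolve_wgrad_output(w_ddp, output_wgrad_to_main_grad=True)
assert buf is w_ddp.main_grad and acc

# FSDP-style owner: buffer fetched through a getter and overwritten each step.
w_fsdp = torch.nn.Parameter(torch.randn(4, 4))
w_fsdp.get_main_grad = lambda: torch.zeros_like(w_fsdp)
w_fsdp._overwrite_main_grad = True
buf, acc = resolve_wgrad_output(w_fsdp, output_wgrad_to_main_grad=True)
assert not acc
```

The point is that the behavior is driven by attributes on the parameter, which the gradient-buffer owner controls, rather than by whether the caller happens to be DDP or FSDP.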
Signed-off-by: Selvaraj Anandaraj <[email protected]>
We should include this behavior in the documentation:
fuse_wgrad_accumulation : bool, default = `False`
accumulate_into_main_grad : bool, default = `False`
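As one possible wording (a sketch only, not final documentation text; the `get_main_grad` and `_overwrite_main_grad` attribute names follow the pseudocode above and are assumptions), the expanded docstring could read:

```
fuse_wgrad_accumulation : bool, default = `False`
    If `True`, the weight gradient GEMM writes into the buffer exposed by the
    parameter (`weight.main_grad`, or the tensor returned by
    `weight.get_main_grad()` when that method exists) instead of producing a
    separate `weight.grad`. The result is accumulated into that buffer unless
    `weight._overwrite_main_grad` is set, in which case the buffer is
    overwritten. The caller (e.g. Megatron-core DDP/FSDP) owns allocation and
    zeroing of the buffer.
```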
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Selvaraj Anandaraj <[email protected]>
…ar.py Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch
LGTM, pending CI.
It seems that DCO didn't like some commits, but they look fine to me. Maybe there's something misconfigured with your GitHub account's emails or maybe DCO is just buggy? In any case, I'm happy leaving this PR as-is and overriding DCO.
This PR adds gradient fusion support for MCore FSDP.
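For context, a minimal usage sketch of the fused wgrad accumulation path (illustrative only; assumes a CUDA build of Transformer Engine and that the caller, like Mcore, attaches the `main_grad` buffers):

```python
import torch
import transformer_engine.pytorch as te

# Fused wgrad accumulation: the weight gradient GEMM accumulates directly
# into weight.main_grad instead of materializing a separate weight.grad.
layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True)

# The caller owns the main_grad buffers, as Mcore's DDP/FSDP grad buffers do.
for p in layer.parameters():
    p.main_grad = torch.zeros_like(p, dtype=torch.float32)

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
layer(x).sum().backward()  # weight gradient lands in layer.weight.main_grad
```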