feat(ce,flce): decouple gradients computation for no_grad mode #894
Signed-off-by: Tcc0403 <[email protected]>
Nice, thanks for adding this. Is the FLCE forward still significantly slower than HF because we're still computing grad_input and applying the token scaling logic? Also, do you know why fp32 accum is faster?
@shimizust The slower forward pass is kinda expected because:
I just found that we can remove this line in eval mode as well. Cutting another matmul for each iteration should be significant, and it saves some grad tensor allocations too.
## Summary
follow-up #894

## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------
Signed-off-by: Tcc0403 <[email protected]>
Summary
Add a flag `HAS_GRADIENTS` to the cross entropy kernel, so gradients are no longer computed when they are not needed.
Testing Done
Cross Entropy forward with no_grad

Fused Linear Cross Entropy forward with no_grad

- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence
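To illustrate the idea behind the `HAS_GRADIENTS` flag from the Summary, here is a minimal pure-Python sketch of a cross entropy forward pass for a single row of logits. It is not the Triton kernel from this PR: the function name `cross_entropy_forward` and the `has_gradients` parameter are hypothetical stand-ins, and the sketch only shows how a flag can skip the gradient work when the caller runs under `no_grad`.

```python
import math


def cross_entropy_forward(logits, target, has_gradients=True):
    """Return (loss, grad) for one row of logits.

    Hypothetical sketch: `has_gradients` plays the role of the
    HAS_GRADIENTS flag. When it is False, the gradient is never
    computed or allocated and `grad` is returned as None.
    """
    # Numerically stable log-sum-exp.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    loss = math.log(total) + m - logits[target]

    if not has_gradients:
        # no_grad / eval path: skip the gradient work entirely.
        return loss, None

    # Gradient of the loss w.r.t. the logits: softmax(logits) - one_hot(target).
    grad = [e / total for e in exps]
    grad[target] -= 1.0
    return loss, grad
```

Calling `cross_entropy_forward(logits, target, has_gradients=False)` returns the same loss as the default call, but without producing the gradient list.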