Add Float8BlockwiseLinear with Triton kernels for quantization and GEMMs #2592
error = torch.norm(C - C_q) / torch.norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.1, "Quantized gemm error is too high"
Can you use SQNR everywhere to match with the existing numerics testing?
Updated to use SQNR
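For reference, a minimal sketch of an SQNR-based check along the lines suggested above (the helper, threshold, and placeholder tensors are illustrative, not the actual test code):

```python
import torch

def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - actual||)
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - actual))

# e.g. comparing a high-precision reference GEMM output C against the
# blockwise fp8 GEMM output C_q (placeholder tensors here):
C = torch.randn(512, 512)
C_q = C + 1e-2 * torch.randn_like(C)
assert sqnr(C, C_q) > 25.0, "SQNR too low"  # threshold is illustrative
```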
# original implementation from fbgemm_gpu:
# https://github.com/pytorch/FBGEMM/blob/b19401e913fcdff536dc097fa3013a0a9d66256e/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L3091
def triton_quantize_fp8_block(
Since we have an optional runtime dependency on fbgemm, can we just call their kernel directly?
Yes, that is the desired end state. For now, I have tried and had repeated problems getting it to work (fbgemm-gpu-genai), e.g. undefined symbols. I tried on both H100 and B200 and got different undefined symbol errors.
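For illustration, the optional dependency discussed here could be expressed as a guarded import that prefers the fbgemm_gpu kernel and falls back to the Triton port in this PR. The fbgemm import path follows the link in the snippet above; the torchao-side import path and the wrapper's signature are assumptions, not the actual code in this PR:

```python
# Sketch only: the torchao-side path below is hypothetical, and both kernels are
# assumed to share a (tensor, block_m, block_k) -> (fp8 tensor, scales) signature.
try:
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        triton_quantize_fp8_block as _quantize_fp8_block_impl,
    )
except (ImportError, OSError):  # OSError also covers undefined-symbol failures
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (  # hypothetical path
        triton_quantize_fp8_block as _quantize_fp8_block_impl,
    )


def quantize_fp8_block(x, block_m: int = 128, block_k: int = 128):
    return _quantize_fp8_block_impl(x, block_m, block_k)
```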
This number is kinda weird to me, do you have memory bandwidth calcs? I don't immediately get why there is a 10x delta in groupwise vs blockwise.
Yeah, I agree it's odd, will try adding some mem bw calcs. I was also thinking about checking with Josh / the fbgemm team whether perhaps there is a different kernel they use for activation quant.
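For context, a rough bandwidth calculation of the kind being asked for could look like the sketch below: the quantization kernel reads the high-precision input and writes the fp8 output plus scales, so achieved GB/s is roughly bytes moved divided by measured time (quantize_fn, the scale layout, and the byte accounting are placeholders):

```python
import torch
from triton.testing import do_bench

def achieved_gbps(quantize_fn, x: torch.Tensor, block: int = 128) -> float:
    ms = do_bench(lambda: quantize_fn(x))  # median runtime in ms
    n_scales = x.numel() // block          # 1x128 groups; use (M//128)*(K//128) for 128x128 tiles
    bytes_moved = (
        x.numel() * x.element_size()       # read input (e.g. bf16 -> 2 bytes/elem)
        + x.numel() * 1                    # write fp8 output (1 byte/elem)
        + n_scales * 4                     # write fp32 scales
    )
    return bytes_moved / (ms * 1e-3) / 1e9
```

Comparing this number against the device's peak bandwidth for both the groupwise and blockwise kernels should make a 10x runtime delta easier to attribute.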
Stacked PRs:
Summary
Renames the existing fp8_gemm Triton kernel to blockwise_fp8_gemm_1x128_128x128.
Adds blockwise_fp8_gemm_1x128_1x128_kernel for the dW calculation, where both the left and right operands have activation scaling granularity (1 x block_size). This is a modified version of the kernel above, so it accepts 1x128 scaling for both operands.
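As a rough illustration of the difference between the two scaling granularities named above, the scale tensors for a GEMM with block_size = 128 would have roughly the following shapes (names and orientations here are assumptions for illustration, not the kernels' actual layout):

```python
import torch

M, K, N, block_size = 256, 1024, 512, 128

# 1x128 ("activation"/groupwise) scaling: one scale per row per 128-element group.
a_scales = torch.ones(M, K // block_size)
# 128x128 (weight/blockwise) scaling: one scale per 128x128 tile, as consumed by
# blockwise_fp8_gemm_1x128_128x128.
b_scales_128x128 = torch.ones(K // block_size, N // block_size)
# 1x128 scaling on the second operand as well, as consumed by
# blockwise_fp8_gemm_1x128_1x128_kernel for the dW GEMM.
b_scales_1x128 = torch.ones(N, K // block_size)
```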
Why use Triton kernels instead of DeepGEMM cutlass kernels?
The GEMM APIs in @vkuzo's PoC here no longer exist in DeepGemm. I tried using the new GEMM APIs (fp8_gemm_nt, etc.) and ran into problems. Since our only goal is a functional skeleton and not performance, rather than spend more time on this, I just used the existing Triton kernels we had and made a modified GEMM (a 1-line change) to support blockwise_fp8_gemm_1x128_1x128_kernel.
If we want to replace these Triton GEMMs with the Cutlass ones later to see if perf is better (it probably is), we can do that.
Note on numerics
Interestingly, the reference DeepGEMM Triton quantization kernels do NOT use EPS/clamping to prevent division by 0. This resulted in my unit tests passing (where inputs were drawn from a normal distribution), but NaNs occurring in TorchTitan training runs, where actual activation values sometimes had an amax of 0.
I updated the kernels to use the same EPS guards as torchao.float8, and this fixed the NaNs.
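A minimal sketch of the guard being described, assuming the torchao.float8-style convention of scale = FP8_MAX / clamp(amax, EPS) (the exact constants and scale convention in the PR's kernels may differ):

```python
import torch

EPS = 1e-12  # assumed to match the torchao.float8 guard; illustrative value
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # Without the clamp, an all-zero activation block has amax == 0 and the
    # division produces inf/NaN, which is what surfaced in the TorchTitan runs.
    return FP8_MAX / torch.clamp(amax, min=EPS)
```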
Test plan
pytest test/prototype/blockwise_fp8/test_blockwise_kernels.py
pytest test/prototype/blockwise_fp8/test_blockwise_linear.py
Torchtitan PoC integration results