
Conversation

Valentine233 (Collaborator) commented Sep 12, 2025

Currently, we only check the macro definition CPUBLAS_BRGEMM_F8F8F32 at the API entrance (https://github.com/pytorch/ao/blob/main/torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp#L2533).

To avoid compilation failures with the latest PyTorch (which does not yet define CPUBLAS_BRGEMM_F8F8F32), this PR also adds checks around the FP8 SDPA fused kernels.
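
For reference, the entrance-level check is a preprocessor guard on this macro. A minimal sketch, assuming a hypothetical helper name (the real check lives at the line linked above and may be shaped differently):

```cpp
#include <c10/util/Exception.h>

// Hypothetical sketch of the entrance-level guard (the actual check is at
// quantized_sdpa.cpp:2533): when the PyTorch headers do not define
// CPUBLAS_BRGEMM_F8F8F32, the FP8 SDPA path is rejected up front.
inline void check_fp8_brgemm_supported() {
#ifndef CPUBLAS_BRGEMM_F8F8F32
  TORCH_CHECK(false,
      "FP8 SDPA requires PyTorch built with CPUBLAS_BRGEMM_F8F8F32 "
      "(FP8 BRGEMM support)");
#endif
}
```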


pytorch-bot bot commented Sep 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2991

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a2f495 with merge base c4d4799:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Sep 12, 2025
@Valentine233 Valentine233 added the topic: not user facing label Sep 12, 2025
Valentine233 (Collaborator, Author) commented:

@mingfeima @Xia-Weiwen @jerryzh168 Please kindly help review the PR. Thanks!

jerryzh168 (Contributor) commented Sep 12, 2025

Can you clarify what the issue is without the fix (e.g., the error messages), and the result after the fix?

Valentine233 (Collaborator, Author) commented:

@jerryzh168 Thanks for the review!

Errors are raised when compiling TorchAO with USE_CPU_KERNELS=1 against the PyTorch main branch, which has not yet defined CPUBLAS_BRGEMM_F8F8F32 to support FP8 BRGEMM.
Here are the compile error messages without the fix:

14:05:26  In file included from torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp:10:
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:219:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const c10::Half*, const c10::Half*, float*, bool)’
14:05:26    219 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:227:21: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const c10::Half*’
14:05:26    227 |     const at::Half* A,
14:05:26        |     ~~~~~~~~~~~~~~~~^
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:232:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const c10::BFloat16*, const c10::BFloat16*, float*, bool)’
14:05:26    232 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:240:25: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const c10::BFloat16*’
14:05:26    240 |     const at::BFloat16* A,
14:05:26        |     ~~~~~~~~~~~~~~~~~~~~^
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:245:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const float*, const float*, float*, bool)’
14:05:26    245 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:253:18: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const float*’
14:05:26    253 |     const float* A,
14:05:26        |     ~~~~~~~~~~~~~^
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:258:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const unsigned char*, const unsigned char*, int32_t*, bool)’
14:05:26    258 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:266:26: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const unsigned char*’
14:05:26    266 |     const unsigned char* A,
14:05:26        |     ~~~~~~~~~~~~~~~~~~~~~^
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:271:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const unsigned char*, const signed char*, int32_t*, bool)’
14:05:26    271 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:279:26: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const unsigned char*’
14:05:26    279 |     const unsigned char* A,
14:05:26        |     ~~~~~~~~~~~~~~~~~~~~~^
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:284:16: note: candidate: ‘void at::native::cpublas::brgemm(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, const signed char*, const signed char*, int32_t*, bool)’
14:05:26    284 | TORCH_API void brgemm(
14:05:26        |                ^~~~~~
14:05:26  /home/pytorch/miniforge3/envs/2025ww7_torchao/lib/python3.10/site-packages/torch/include/ATen/native/CPUBlas.h:292:24: note:   no known conversion for argument 8 from ‘c10::Float8_e4m3fn*’ to ‘const signed char*’
14:05:26    292 |     const signed char* A,
14:05:26        |     ~~~~~~~~~~~~~~~~~~~^
14:05:26  torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp: In instantiation of ‘std::enable_if_t<is_same_v<scalar_t, c10::Float8_e4m3fn>, void> torchao::{anonymous}::fp8_sdpa_fused_kernel_impl(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, double, bool, std::optional<at::Tensor>, std::optional<double>, float, float, float, float, float) [with scalar_t = c10::Float8_e4m3fn; mask_t = double; long int q_split_size = 64; long int kv_split_size = 64; std::enable_if_t<is_same_v<scalar_t, c10::Float8_e4m3fn>, void> = void]’:
14:05:26  torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp:2357:5:   required from here
14:05:26  torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp:2000:36: error: no matching function for call to ‘brgemm(int64_t&, int64_t&, int64_t&, int64_t&, int64_t&, int64_t&, bool, const c10::Float8_e4m3fn*, c10::Float8_e4m3fn*, accum_t*&)’
14:05:26   2000 |         at::native::cpublas::brgemm(
14:05:26        |         ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
14:05:26   2001 |           qBlockSize,
14:05:26        |           ~~~~~~~~~~~               
14:05:26   2002 |           kvBlockSize,
14:05:26        |           ~~~~~~~~~~~~              
14:05:26   2003 |           eheadSize,
14:05:26        |           ~~~~~~~~~~                
14:05:26   2004 |           headSize_even ? qStrideM : eheadSize,
14:05:26        |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14:05:26   2005 |           kvBlockSize,
14:05:26        |           ~~~~~~~~~~~~              
14:05:26   2006 |           kvBlockSize,
14:05:26        |           ~~~~~~~~~~~~              
14:05:26   2007 |           false,
14:05:26        |           ~~~~~~                    
14:05:26   2008 |           !headSize_even
14:05:26        |           ~~~~~~~~~~~~~~            
14:05:26   2009 |               ? query_t_padding_ptr
14:05:26        |               ~~~~~~~~~~~~~~~~~~~~~ 
14:05:26   2010 |               : q_data + i * qStrideB + j * qStrideH + m * qStrideM,
14:05:26        |               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14:05:26   2011 |           key_reorder_ptr + i * num_head * eheadSize * kvSize +
14:05:26        |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14:05:26   2012 |               j * eheadSize * kvSize + n * eheadSize,
14:05:26        |               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14:05:26   2013 |           qk_data);
14:05:26        |           ~~~~~~~~                  

With this fix, compilation completes successfully.
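
The added guards wrap the FP8 fused-kernel code itself, so the Float8_e4m3fn brgemm call is only compiled when the PyTorch headers actually provide that overload. A rough sketch of the pattern, with a hypothetical function name and an FP8 overload signature inferred from the candidate list above (the real fp8_sdpa_fused_kernel_impl takes many more arguments and contains the full fused-attention loop):

```cpp
#include <cstdint>
#include <ATen/native/CPUBlas.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>

#if defined(CPUBLAS_BRGEMM_F8F8F32)
// PyTorch defines the Float8_e4m3fn brgemm overload, so this compiles.
void fp8_qk_gemm(int64_t m, int64_t n, int64_t k,
                 const c10::Float8_e4m3fn* q,
                 const c10::Float8_e4m3fn* k_reordered,
                 float* out) {
  at::native::cpublas::brgemm(
      m, n, k,
      /*ld_a=*/k, /*ld_b=*/n, /*ld_c=*/n,
      /*add_C=*/false,
      q, k_reordered, out);
}
#else
// No FP8 BRGEMM in this PyTorch build: keep the symbol so the translation
// unit still compiles, but fail loudly if the FP8 path is ever reached.
void fp8_qk_gemm(int64_t, int64_t, int64_t,
                 const c10::Float8_e4m3fn*,
                 const c10::Float8_e4m3fn*,
                 float*) {
  TORCH_CHECK(false,
      "FP8 SDPA requires PyTorch built with CPUBLAS_BRGEMM_F8F8F32");
}
#endif
```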
