Conversation

@Wangzheee commented Sep 11, 2025

SwapAB: Significantly improve the performance for M%64<32

Description

  • Significantly improves performance when BLOCK_M = 32 or M % 64 < 32
  • Swaps the A and B operands: WGMMA::wgmma(desc_b, desc_a, accum, k) (see the sketch after this list)
  • Uses multiple math warp groups and multi-wave MMAs to support BLOCK_N of 64, 128, and 256; BLOCK_N = 256 is used for H20
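
As a plain-language illustration of why the swap helps (not part of the PR): a WGMMA tile has a fixed M extent of 64, so when the problem's m is much smaller most of the tile rows are wasted; computing the transposed product instead places the small m on the instruction's flexible N extent. A minimal NumPy sketch of the underlying algebra, with made-up shapes:

# Not from the PR: a NumPy sketch of the algebra behind SwapAB.
# WGMMA tiles have a fixed M extent of 64; when the problem's m is small
# (m % 64 < 32), computing the transposed product maps m onto the
# instruction's flexible N extent instead, wasting fewer tile rows.
import numpy as np

m, n, k = 24, 4096, 7168                 # small m: the case this PR targets
A = np.random.randn(m, k).astype(np.float32)
B = np.random.randn(k, n).astype(np.float32)

C_normal = A @ B                         # C = A * B
C_swapab = (B.T @ A.T).T                 # C = (B^T * A^T)^T, same result

assert np.allclose(C_normal, C_swapab, rtol=1e-3, atol=1e-3)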

How to use

  • export ENABLE_SWAPAB=1
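
A minimal sketch (not from the PR) of enabling the flag from Python instead of the shell; the import name deep_gemm and the call site are assumptions. The flag is read with get_env<int>("ENABLE_SWAPAB") on the C++ side (see the TODO snippet below), so it should be set before the kernels are configured:

# Hedged sketch: set the flag in-process before the library configures kernels.
import os
os.environ["ENABLE_SWAPAB"] = "1"   # same effect as `export ENABLE_SWAPAB=1`

import deep_gemm                    # assumed import name; set the flag first
# ... then run the masked grouped GEMM / test script exactly as before.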

Improvements (H20)

Aligned M (the desired case, every group filled to exactly expected_m_per_group rows): masked_m[j] = int(expected_m_per_group * random.uniform(1, 1))

  • Original

Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D): 257 us | 117 TFLOPS | 1862 GB/s
Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D): 127 us | 118 TFLOPS | 1914 GB/s
Perf (num_groups=16, expected_m_per_group= 96, n=4096, k=7168, 1D2D): 497 us | 181 TFLOPS | 993 GB/s
Perf (num_groups=16, expected_m_per_group= 96, n=7168, k=2048, 1D2D): 240 us | 188 TFLOPS | 1085 GB/s
Perf (num_groups=16, expected_m_per_group= 160, n=4096, k=7168, 1D2D): 921 us | 163 TFLOPS | 554 GB/s
Perf (num_groups=16, expected_m_per_group= 160, n=7168, k=2048, 1D2D): 472 us | 159 TFLOPS | 587 GB/s

  • SwapAB

Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D): 154 us | 195 TFLOPS | 3100 GB/s
Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D): 84 us | 180 TFLOPS | 2906 GB/s
Perf (num_groups=16, expected_m_per_group= 96, n=4096, k=7168, 1D2D): 340 us | 266 TFLOPS | 1454 GB/s
Perf (num_groups=16, expected_m_per_group= 96, n=7168, k=2048, 1D2D): 191 us | 236 TFLOPS | 1364 GB/s
Perf (num_groups=16, expected_m_per_group= 160, n=4096, k=7168, 1D2D): 572 us | 263 TFLOPS | 891 GB/s
Perf (num_groups=16, expected_m_per_group= 160, n=7168, k=2048, 1D2D): 303 us | 248 TFLOPS | 915 GB/s
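
As a sanity check (not part of the PR), the TFLOPS figures can be reproduced from the timings, assuming the benchmark counts 2*m*n*k FLOPs with m summed over all groups; for the first SwapAB line above:

# Hedged sketch: reproduce a reported TFLOPS number from the raw timing.
num_groups, m_per_group, n, k = 16, 32, 4096, 7168
time_s = 154e-6                                   # 154 us from the SwapAB run
flops = 2 * (num_groups * m_per_group) * n * k    # ~3.0e10 floating-point ops
print(f"{flops / time_s / 1e12:.0f} TFLOPS")      # -> 195, matching the table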

Other case (the original test setup, varied per-group m): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))

  • Original

Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D): 349 us | 214 TFLOPS | 141 GB/s
Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D): 159 us | 215 TFLOPS | 213 GB/s
Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D): 348 us | 178 TFLOPS | 216 GB/s
Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D): 159 us | 191 TFLOPS | 292 GB/s
Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D): 347 us | 174 TFLOPS | 384 GB/s
Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D): 217 us | 146 TFLOPS | 353 GB/s
Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D): 406 us | 153 TFLOPS | 1199 GB/s
Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D): 172 us | 164 TFLOPS | 1455 GB/s
Perf (num_groups=16, expected_m_per_group= 128, n=4096, k=7168, 1D2D): 740 us | 165 TFLOPS | 678 GB/s
Perf (num_groups=16, expected_m_per_group= 128, n=7168, k=2048, 1D2D): 354 us | 168 TFLOPS | 758 GB/s

  • SwapAB

Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D): 304 us | 246 TFLOPS | 162 GB/s
Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D): 145 us | 235 TFLOPS | 233 GB/s
Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D): 271 us | 229 TFLOPS | 278 GB/s
Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D): 136 us | 224 TFLOPS | 342 GB/s
Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D): 271 us | 223 TFLOPS | 493 GB/s
Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D): 146 us | 217 TFLOPS | 523 GB/s
Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D): 308 us | 201 TFLOPS | 1580 GB/s
Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D): 161 us | 175 TFLOPS | 1556 GB/s
Perf (num_groups=16, expected_m_per_group= 128, n=4096, k=7168, 1D2D): 539 us | 227 TFLOPS | 931 GB/s
Perf (num_groups=16, expected_m_per_group= 128, n=7168, k=2048, 1D2D): 272 us | 218 TFLOPS | 986 GB/s

TODO

When SwapAB is enabled, BLOCK_M and BLOCK_N are currently hard-coded (see the author's note below); selecting them automatically from the commented candidates is left as a TODO:

block_ns.push_back(i);
if (get_env<int>("ENABLE_SWAPAB")) {
    block_ms = std::vector{32};   // candidates: 32, 64
    block_ns = std::vector{256};  // candidates: 64, 128, 256
}

@Wangzheee (Author): Manually set one of them; experiments have found that, in most cases, 256 performs the best.

@LyricZhao (Collaborator): Thanks! Merging it later.
