
Conversation

@botbigeyes
Contributor

PR Category
Operator

Type of Change
Bug Fix

Description
Fix the GroupNorm memory-access bug for large tensors by handling the H×W dimension with a loop-blocking scheme, similar to the one used in the backward pass, so that the per-program block no longer grows with the input and the block numel no longer exceeds the supported limit.
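
For readers who want a concrete picture of the loop-blocking idea, here is a minimal, self-contained Triton sketch (kernel and helper names are illustrative, not the actual FlagGems group_norm_kernel): each program computes the statistics of one (batch, group) row by walking it in fixed-size chunks, instead of materializing a single offset tensor of size next_power_of_2(H×W).

```python
import torch
import triton
import triton.language as tl


@triton.jit
def group_stats_blocked_kernel(
    X,                 # input viewed as (N * num_groups, group_size * H * W), contiguous
    Mean,              # per-(batch, group) mean, shape (N * num_groups,)
    Rstd,              # per-(batch, group) reciprocal std, shape (N * num_groups,)
    row_numel,         # group_size * H * W
    eps,
    BLOCK_SIZE: tl.constexpr,   # fixed (e.g. 1024), independent of H * W
):
    # One program handles one (batch, group) row and loops over it in
    # BLOCK_SIZE chunks, so the index tensor never grows with H * W.
    row = tl.program_id(0).to(tl.int64)
    base = X + row * row_numel

    acc_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    acc_sq = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, row_numel, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        mask = cols < row_numel
        x = tl.load(base + cols, mask=mask, other=0.0).to(tl.float32)
        acc_sum += x
        acc_sq += x * x

    mean = tl.sum(acc_sum, axis=0) / row_numel
    var = tl.sum(acc_sq, axis=0) / row_numel - mean * mean
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, 1.0 / tl.sqrt(var + eps))


def group_norm_stats(x: torch.Tensor, num_groups: int, eps: float = 1e-5):
    # Hypothetical host-side helper: one program per (batch, group) row.
    N = x.shape[0]
    x2d = x.contiguous().view(N * num_groups, -1)
    mean = torch.empty(N * num_groups, dtype=torch.float32, device=x.device)
    rstd = torch.empty_like(mean)
    group_stats_blocked_kernel[(N * num_groups,)](
        x2d, mean, rstd, x2d.shape[1], eps, BLOCK_SIZE=1024
    )
    return mean, rstd
```

The per-program working set is bounded by BLOCK_SIZE regardless of H×W, which is what avoids the oversized xy_offset index tensor for large inputs.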

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

benchmark/test_norm_perf.py::test_group_and_layer_and_instance_norm_benchmark[group_norm-group_norm-groupnorm_input_fn] 
Operator: group_norm  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.015424            0.029664               0.520          [torch.Size([4, 16, 64, 4]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016800            0.006784               2.476          [torch.Size([4, 16, 64, 4]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015072            0.007168               2.103          [torch.Size([16, 16, 8, 48]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.020064            0.007008               2.863          [torch.Size([16, 16, 8, 48]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017312            0.007552               2.292          [torch.Size([16, 16, 8, 88]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017216            0.007392               2.329          [torch.Size([16, 16, 8, 88]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016480            0.007008               2.352          [torch.Size([16, 16, 128]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016416            0.007008               2.342          [torch.Size([16, 16, 128]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.140320            0.072192               1.944          [torch.Size([20, 6, 65536]), 3, torch.Size([6]), torch.Size([6])]
SUCCESS               0.087584            0.040128               2.183          [torch.Size([20, 6, 65536]), 6, torch.Size([6]), torch.Size([6])]
SUCCESS               0.015008            0.006976               2.151          [torch.Size([16, 16, 64]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015200            0.006880               2.209          [torch.Size([16, 16, 64]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016352            0.007648               2.138          [torch.Size([16, 16, 1024]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016544            0.007488               2.209          [torch.Size([16, 16, 1024]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.024608            0.012512               1.967          [torch.Size([16, 16, 4098]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.021664            0.011040               1.962          [torch.Size([16, 16, 4098]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.013312            0.006688               1.990          [torch.Size([1, 8, 4, 4]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.011936            0.006656               1.793          [torch.Size([1, 8, 4, 4]), 8, torch.Size([8]), torch.Size([8])]
SUCCESS               0.046048            0.018592               2.477          [torch.Size([16, 8, 128, 128]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.033824            0.015168               2.230          [torch.Size([16, 8, 128, 128]), 8, torch.Size([8]), torch.Size([8])]
Operator: group_norm  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.013472            0.006784               1.986          [torch.Size([4, 16, 64, 4]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017152            0.006720               2.552          [torch.Size([4, 16, 64, 4]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016256            0.007424               2.190          [torch.Size([16, 16, 8, 48]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.018624            0.007232               2.575          [torch.Size([16, 16, 8, 48]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016000            0.007328               2.183          [torch.Size([16, 16, 8, 88]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015168            0.007136               2.126          [torch.Size([16, 16, 8, 88]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017600            0.006976               2.523          [torch.Size([16, 16, 128]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.014752            0.006784               2.175          [torch.Size([16, 16, 128]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.149680            0.130240               1.149          [torch.Size([20, 6, 65536]), 3, torch.Size([6]), torch.Size([6])]
SUCCESS               0.097984            0.101536               0.965          [torch.Size([20, 6, 65536]), 6, torch.Size([6]), torch.Size([6])]
SUCCESS               0.015744            0.006880               2.288          [torch.Size([16, 16, 64]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.014720            0.006752               2.180          [torch.Size([16, 16, 64]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016768            0.007872               2.130          [torch.Size([16, 16, 1024]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015456            0.007680               2.013          [torch.Size([16, 16, 1024]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.025216            0.013472               1.872          [torch.Size([16, 16, 4098]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.021184            0.024864               0.852          [torch.Size([16, 16, 4098]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.013216            0.007184               1.840          [torch.Size([1, 8, 4, 4]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.011648            0.006528               1.784          [torch.Size([1, 8, 4, 4]), 8, torch.Size([8]), torch.Size([8])]
SUCCESS               0.045952            0.021440               2.143          [torch.Size([16, 8, 128, 128]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.033184            0.018624               1.782          [torch.Size([16, 8, 128, 128]), 8, torch.Size([8]), torch.Size([8])]
Operator: group_norm  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.014432            0.006880               2.098          [torch.Size([4, 16, 64, 4]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017792            0.006784               2.623          [torch.Size([4, 16, 64, 4]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015072            0.007200               2.093          [torch.Size([16, 16, 8, 48]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.019104            0.007072               2.701          [torch.Size([16, 16, 8, 48]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016384            0.007552               2.169          [torch.Size([16, 16, 8, 88]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017248            0.007456               2.313          [torch.Size([16, 16, 8, 88]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.018112            0.007008               2.584          [torch.Size([16, 16, 128]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.015392            0.007072               2.176          [torch.Size([16, 16, 128]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.139712            0.072128               1.937          [torch.Size([20, 6, 65536]), 3, torch.Size([6]), torch.Size([6])]
SUCCESS               0.088864            0.040320               2.204          [torch.Size([20, 6, 65536]), 6, torch.Size([6]), torch.Size([6])]
SUCCESS               0.016160            0.006976               2.317          [torch.Size([16, 16, 64]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.013632            0.006976               1.954          [torch.Size([16, 16, 64]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.017984            0.007584               2.371          [torch.Size([16, 16, 1024]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.016512            0.007552               2.186          [torch.Size([16, 16, 1024]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.023680            0.012416               1.907          [torch.Size([16, 16, 4098]), 8, torch.Size([16]), torch.Size([16])]
SUCCESS               0.020000            0.011136               1.796          [torch.Size([16, 16, 4098]), 16, torch.Size([16]), torch.Size([16])]
SUCCESS               0.012032            0.006752               1.782          [torch.Size([1, 8, 4, 4]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.012288            0.006560               1.873          [torch.Size([1, 8, 4, 4]), 8, torch.Size([8]), torch.Size([8])]
SUCCESS               0.046976            0.018752               2.505          [torch.Size([16, 8, 128, 128]), 4, torch.Size([8]), torch.Size([8])]
SUCCESS               0.032864            0.015264               2.153          [torch.Size([16, 8, 128, 128]), 8, torch.Size([8]), torch.Size([8])]

@CLAassistant

CLAassistant commented Nov 14, 2025

CLA assistant check
All committers have signed the CLA.

@0x45f
Collaborator

0x45f commented Nov 17, 2025

plz sign CLA

@botbigeyes
Contributor Author

> plz sign CLA

ok

  eps,
  BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
- BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
+ BLOCK_HW_SIZE=1024,
Collaborator


Why use fixed BLOCK_HW_SIZE?

Contributor Author


The previous bug was caused by BLOCK_HW_SIZE=triton.next_power_of_2(HxW), which made the index tensor xy_offset too large. Similar to the backward logic, we can fix BLOCK_HW_SIZE and introduce loop-blocking to handle large dimension sizes. This line can be deleted, as the default value is already used in the group_norm_kernel.
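
As a quick illustration (the numbers below are just triton.next_power_of_2 arithmetic, not benchmark results), the old launch made the per-program block scale with the spatial size, while a fixed BLOCK_HW_SIZE stays constant and the kernel instead runs more loop iterations, roughly ceil(HxW / 1024):

```python
import triton

# Old: BLOCK_HW_SIZE = triton.next_power_of_2(HxW) grows with the input.
# New: BLOCK_HW_SIZE is fixed at 1024 and the kernel loops over HxW.
for hxw in (128, 4098, 16384, 65536):
    old_block = triton.next_power_of_2(hxw)
    loops = (hxw + 1024 - 1) // 1024
    print(f"HxW={hxw:6d}  old BLOCK_HW_SIZE={old_block:6d}  new: 1024 x {loops} iterations")
```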

