【Triton Copilot】Fix rms_norm backward precision issue #1084
Description
This PR fixes the numerical precision issue in the `rms_norm` backward implementation reported in #1068.
Development Tool:
Problem
The backward pass of `rms_norm` had a precision issue where:
- `x.grad` contained NaN/Inf values
Root Cause
The issue was caused by not ensuring that `dy` is contiguous before passing it to the kernel. When `dy` is not contiguous, the stride parameters passed to the kernel are incorrect, causing the kernel to read from the wrong memory locations.
Solution
- Add `dy = dy.contiguous()` in the backward function to ensure correct stride parameters
- Change `partial_buffer` initialization from `torch.empty()` to `torch.zeros()` for better robustness, since `torch.empty()` returns uninitialized memory
Changes
- Add `dy = dy.contiguous()` in `RmsNorm.backward()` (line 186)
- Change `partial_buffer = torch.empty(...)` to `partial_buffer = torch.zeros(...)` (line 200)
Verification
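As a rough sketch of how this kind of fix can be sanity-checked (this is not the test used in this PR), one can backpropagate a non-contiguous upstream gradient and confirm that the resulting gradients are finite and match those from a contiguous copy. The pure-PyTorch `rms_norm_ref` below is a hypothetical reference implementation standing in for the Triton operator, and the shapes, `eps` value, and seed are arbitrary assumptions.

```python
import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # Hypothetical reference: y = x / sqrt(mean(x^2) + eps) * weight
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

torch.manual_seed(0)
x = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
w = torch.randn(16, dtype=torch.float64, requires_grad=True)

# Non-contiguous upstream gradient: transposed view of a (16, 8) buffer.
dy = torch.randn(16, 8, dtype=torch.float64).t()
assert not dy.is_contiguous()

rms_norm_ref(x, w).backward(dy)

# Gradients must not contain NaN/Inf.
assert torch.isfinite(x.grad).all()
assert torch.isfinite(w.grad).all()

# The same backward pass with a contiguous copy of dy must give the same result.
x2 = x.detach().clone().requires_grad_(True)
w2 = w.detach().clone().requires_grad_(True)
rms_norm_ref(x2, w2).backward(dy.contiguous())
assert torch.allclose(x.grad, x2.grad)
assert torch.allclose(w.grad, w2.grad)
```

To test the actual operator, the first `rms_norm_ref` call would be replaced by the Triton-backed function, keeping the reference on the second path for comparison.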
Test Results
Using the same test data from issue #1068:
Fixes #1068