-
Notifications
You must be signed in to change notification settings - Fork 426
Modified block and warp sizes for improved performance on XPU for both layernnorm and rmsnorm #661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Tarakarevu1
wants to merge
5
commits into
linkedin:main
Choose a base branch
from
Tarakarevu1:Tarakarevu1-patch-3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,11 +197,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): | |
| _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) | ||
| _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) | ||
|
|
||
| BLOCK_SIZE, num_warps = calculate_settings(n_cols) | ||
| if n_cols > BLOCK_SIZE: | ||
| raise RuntimeError( | ||
| f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." | ||
| ) | ||
| if X.device.type == "xpu": # XPU-specific optimization | ||
| BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size | ||
| else: | ||
| BLOCK_SIZE, num_warps = calculate_settings(n_cols) | ||
|
|
||
| if X.device.type == "xpu": # XPU-specific optimization | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Tarakarevu1 I think this is not an optimization but rather a guardrail to avoid spilling beyond 64 KB ? |
||
| if n_cols > 65536: | ||
| raise RuntimeError( | ||
| f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." | ||
| ) | ||
| else: | ||
| if n_cols > BLOCK_SIZE: | ||
| raise RuntimeError( | ||
| f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." | ||
| ) | ||
|
|
||
| rows_per_program = math.ceil(n_rows / sm_count) | ||
| grid = (sm_count,) | ||
|
|
@@ -218,7 +228,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): | |
| # XPU-specific optimization | ||
| kernel_args = {} | ||
| if X.device.type == "xpu": | ||
| kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4}) | ||
| kernel_args.update({"grf_mode": "large", "num_warps": 4, "num_stages": 4}) | ||
|
|
||
| _layer_norm_backward_kernel[grid]( | ||
| X, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
afaik we want the block size here to be
triton.next_power_of_2(n).torch.xpu.get_device_properties(X.device).max_work_group_sizeis supposedly giving the max possible block size on the xpu device. Better way to handle this is to change the calculate_settings function directly where you set theMAX_FUSED_SIZEaccording to the device.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shivam15s These changes are specific for the layernorm and rmsnorm.That's why i made the changes locally and this will impact other kernels.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Tarakarevu1 how much is the perf improvment of layernorm and rmsnorm after this change?