22 changes: 16 additions & 6 deletions src/liger_kernel/ops/layer_norm.py
@@ -197,11 +197,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
     _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
+    if X.device.type == "xpu":  # XPU-specific optimization
+        BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size
Collaborator:
afaik we want the block size here to be triton.next_power_of_2(n). torch.xpu.get_device_properties(X.device).max_work_group_size supposedly gives the maximum possible block size on the XPU device. A better way to handle this is to change the calculate_settings function directly and set MAX_FUSED_SIZE according to the device.
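
For illustration only, a device-aware calculate_settings along the lines of this suggestion might look like the sketch below. The default cap of 65536 and the num_warps heuristic are assumptions about the existing helper, not a copy of it.

```python
# Illustrative sketch of the suggestion above, not the actual liger_kernel code.
# Idea: pick the block-size cap per device inside calculate_settings instead of
# special-casing XPU inside every kernel.
import torch
import triton


def calculate_settings(n, device=None):
    # Assumed default cap for CUDA-like devices.
    max_fused_size = 65536
    if device is not None and torch.device(device).type == "xpu":
        # On XPU, cap the block size at the device's maximum work-group size.
        max_fused_size = torch.xpu.get_device_properties(device).max_work_group_size

    block_size = triton.next_power_of_2(n)
    if block_size > max_fused_size:
        raise RuntimeError(
            f"Feature dimension {n} exceeds the maximum fused size "
            f"{max_fused_size} supported on this device."
        )

    # Assumed warp heuristic; the real helper may scale num_warps differently.
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps
```

Callers would then pass X.device and keep a single code path for CUDA and XPU.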

Contributor (author):
@shivam15s These changes are specific to the layernorm and rmsnorm, which is why I made them locally in those kernels; changing calculate_settings directly would impact the other kernels.

Collaborator:
@Tarakarevu1 How much is the perf improvement of layernorm and rmsnorm after this change?
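
For reference, a minimal timing harness like the sketch below could be used to compare the two commits on the same shapes. The wrapped call and shapes are placeholders; only time.perf_counter and the torch.xpu / torch.cuda synchronize calls are assumed.

```python
# Minimal benchmarking sketch (not part of this PR). Wrap whichever
# layernorm/rmsnorm call is under test in `fn` and time it on both commits.
import time

import torch


def benchmark_ms(fn, warmup=10, iters=100, device_type="xpu"):
    sync = torch.xpu.synchronize if device_type == "xpu" else torch.cuda.synchronize
    for _ in range(warmup):  # warm up kernel compilation and caches
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters * 1e3  # milliseconds per call


# Hypothetical usage (module name and shapes are placeholders):
# x = torch.randn(4096, 8192, device="xpu", dtype=torch.bfloat16)
# print(benchmark_ms(lambda: layer_norm_module(x)))
```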

+    else:
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    if X.device.type == "xpu":  # XPU-specific optimization
Contributor:
@Tarakarevu1 I think this is not an optimization but rather a guardrail to avoid spilling beyond 64 KB?
LGTM on the code changes, although the RuntimeError could be tweaked so the message reflects the XPU-specific limit?
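
One possible tweak along those lines (purely illustrative, not what the PR currently does) would be to name the device and the hard limit in the message:

```python
# Illustrative only: make the XPU guardrail explicit in the error message.
if X.device.type == "xpu" and n_cols > 65536:
    raise RuntimeError(
        f"LayerNorm backward on XPU supports a feature dimension of at most "
        f"65536; got n_cols={n_cols}. Consider reducing the hidden size."
    )
```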

+        if n_cols > 65536:
+            raise RuntimeError(
+                f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+            )
+    else:
+        if n_cols > BLOCK_SIZE:
+            raise RuntimeError(
+                f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+            )

     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
@@ -218,7 +228,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+        kernel_args.update({"grf_mode": "large", "num_warps": 4, "num_stages": 4})

     _layer_norm_backward_kernel[grid](
         X,
16 changes: 13 additions & 3 deletions src/liger_kernel/ops/rms_norm.py
@@ -212,7 +212,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    if X.device.type == "xpu":  # XPU-specific optimization
+        BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size
+        num_warps = 4
+    else:
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row

@@ -262,8 +266,14 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

-    if n_cols > BLOCK_SIZE:
-        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    if X.device.type == "xpu":  # XPU-specific optimization
+        if n_cols > 65536:
+            raise RuntimeError(
+                "This layer norm doesn't support feature dim >= 64KB."
+            )  # TODO RuntimeError might need little more investigation in the future
+    else:
+        if n_cols > BLOCK_SIZE:
+            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
