5 changes: 4 additions & 1 deletion benchmark/scripts/benchmark_cross_entropy.py
@@ -70,6 +70,9 @@ def fwd():

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "no-grad-forward":
with torch.no_grad():
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "backward":
y = fwd()

@@ -109,7 +112,7 @@ def full():

run_benchmarks(
bench_test_fn=bench_speed_cross_entropy,
kernel_operation_modes=["forward", "backward", "full"],
kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
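
For context, here is a minimal standalone sketch of what the new "no-grad-forward" mode times: the same forward closure, but benchmarked with autograd disabled so no graph is recorded during the measured iterations. Torch's built-in CrossEntropyLoss stands in for the benchmarked modules, and the QUANTILES ordering is an assumption chosen to match the ms_50, ms_20, ms_80 unpacking above.

import torch
import triton

QUANTILES = [0.5, 0.2, 0.8]  # assumed ordering, matching ms_50, ms_20, ms_80

loss_fn = torch.nn.CrossEntropyLoss()
logits = torch.randn(8 * 1024, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8 * 1024,), device="cuda")

def fwd():
    return loss_fn(logits, target)

# "forward": timed with autograd enabled
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)

# "no-grad-forward": identical closure, timed inside torch.no_grad()
with torch.no_grad():
    ng_50, ng_20, ng_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
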
44 changes: 25 additions & 19 deletions benchmark/scripts/benchmark_fused_linear_cross_entropy.py
@@ -59,26 +59,26 @@ def bench_memory_fused_linear_cross_entropy(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
lm_head_ce = None
if provider == "liger":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)

def fwd():
if provider == "liger":
return liger_lm_head_ce(_input, target)
elif provider == "liger-fp32-accum":
return liger_lm_head_ce_fp32_accum(_input, target)
elif provider == "huggingface":
return torch_lm_head_ce(_input, target)
return lm_head_ce(_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
@@ -101,27 +101,33 @@ def bench_speed_fused_linear_cross_entropy(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
lm_head_ce = None
if provider == "liger":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)

def fwd():
if provider == "liger":
return liger_lm_head_ce(_input, target)
elif provider == "liger-fp32-accum":
return liger_lm_head_ce_fp32_accum(_input, target)
elif provider == "huggingface":
return torch_lm_head_ce(_input, target)
return lm_head_ce(_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "no-grad-forward":
with torch.no_grad():
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()

@@ -164,7 +170,7 @@ def full():

run_benchmarks(
bench_test_fn=bench_speed_fused_linear_cross_entropy,
kernel_operation_modes=["forward", "backward", "full"],
kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
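
One motivation for building only the selected provider's module (instead of all three up front, as before) is that anything allocated before the measured closure counts toward peak GPU memory in the memory benchmark, so unused modules inflate the reported numbers. A small self-contained illustration of that effect, independent of the Liger modules:

import torch

def peak_mib(build_unused: bool) -> float:
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # An extra module we never exercise still occupies GPU memory.
    unused = torch.nn.Linear(4096, 32000, device="cuda", dtype=torch.bfloat16) if build_unused else None
    module = torch.nn.Linear(4096, 32000, device="cuda", dtype=torch.bfloat16)
    module(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
    return torch.cuda.max_memory_allocated() / 2**20

print("peak MiB with an extra, unused head:", peak_mib(build_unused=True))
print("peak MiB with only the head under test:", peak_mib(build_unused=False))
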
107 changes: 55 additions & 52 deletions src/liger_kernel/ops/cross_entropy.py
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
HAS_GRADIENTS: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
HAS_GRADIENTS (bool): The boolean value to determine whether to compute gradients for the input in the forward pass.
"""

# https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), for i ≠ y

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss

# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
if HAS_GRADIENTS:
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss

# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=_input.requires_grad,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32 if not is_hip() else 16,
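
Because HAS_GRADIENTS is a tl.constexpr, Triton specializes the kernel at compile time: with HAS_GRADIENTS=False the entire gradient loop above is absent from the generated code rather than branched over at runtime, so a forward-only call pays for the loss computation but not for writing gradients back into X_ptr. A minimal illustrative kernel (not the Liger kernel) showing the same constexpr-gating pattern:

import torch
import triton
import triton.language as tl

@triton.jit
def double_in_place(X_ptr, n_cols, BLOCK_SIZE: tl.constexpr, HAS_GRADIENTS: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    if HAS_GRADIENTS:  # resolved when the kernel is compiled; the False variant does no work
        x = tl.load(X_ptr + offsets, mask=mask)
        tl.store(X_ptr + offsets, x * 2.0, mask=mask)

x = torch.ones(64, device="cuda")
double_in_place[(1,)](x, x.numel(), BLOCK_SIZE=64, HAS_GRADIENTS=False)  # x stays all ones
double_in_place[(1,)](x, x.numel(), BLOCK_SIZE=64, HAS_GRADIENTS=True)   # x becomes all twos
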
5 changes: 3 additions & 2 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
RETURN_Z_LOSS=return_z_loss,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=_input.requires_grad,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(

grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

if grad_weight is not None:
if grad_weight is not None and _input.requires_grad:
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

if bias is not None:
if bias is not None and _input.requires_grad:
torch.add(
input=grad_bias,
other=grad_logits_chunk.sum(dim=0),
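
The two new _input.requires_grad guards apply the same idea on the PyTorch side: when the input does not require grad (for example, during evaluation), the per-chunk grad_weight / grad_bias accumulation is skipped entirely. A simplified, self-contained sketch of that gating, with illustrative shapes and names:

import torch

H, V, chunk_size = 64, 128, 16
_input_chunk = torch.randn(chunk_size, H, device="cuda")        # requires_grad is False here
grad_logits_chunk = torch.randn(chunk_size, V, device="cuda")
grad_weight = torch.zeros(V, H, device="cuda", dtype=torch.float32)
grad_bias = torch.zeros(V, device="cuda", dtype=torch.float32)

if _input_chunk.requires_grad:
    # Only reached for training-style inputs; inference calls skip both accumulations.
    grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
    grad_bias += grad_logits_chunk.sum(dim=0)
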
45 changes: 45 additions & 0 deletions test/transformers/test_cross_entropy.py
@@ -455,6 +455,28 @@ def _test_correctness_not_last_layer_with_other_params_once(
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_forward_only(target_ce, B, T, V, reduction, dtype, scalar, atol, rtol):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(reduction=reduction)

_tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
_input = _tensor.detach().clone()
_input2 = _tensor.detach().clone()

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

with torch.no_grad():
output = torch_ce(_input, target)
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)

try:
# Try running backward on liger output
output2.backward(gradient=torch.ones_like(output))
except RuntimeError as e:
assert "does not require grad" in str(e)


def _test_correctness_functional(
B,
T,
@@ -1014,6 +1036,7 @@ def test_float32_internal():
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=False,
HAS_GRADIENTS=True,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
@@ -1042,6 +1065,7 @@ def test_float32_internal():
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=False,
HAS_GRADIENTS=True,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
@@ -1061,3 +1085,24 @@
def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
_test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)


@pytest.mark.parametrize(
"B, T, V, ignore_index",
[
(2, 4096, 32000, -100),
(3, 423, 32000, 2),
],
)
@pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
@pytest.mark.parametrize(
"dtype, scalar, atol, rtol",
[
(torch.float32, 1.0, 1e-4, 1e-4),
(torch.float16, 1.0, 1e-2, 1e-2),
(torch.bfloat16, 1.0, 1e-2, 1e-2),
],
)
def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol):
liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
_test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol)
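
A hedged usage sketch of the behavior the new test exercises: when the logits tensor does not require grad (the typical evaluation case), cross_entropy_forward launches the kernel with HAS_GRADIENTS=False and only the loss is produced. The import path below is assumed to be the usual liger_kernel.transformers export.

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(reduction="mean")
logits = torch.randn(4 * 512, 32000, device="cuda", dtype=torch.bfloat16)  # requires_grad defaults to False
labels = torch.randint(0, 32000, (4 * 512,), device="cuda")

with torch.no_grad():
    val_loss = loss_fn(logits, labels)

# Calling val_loss.backward() here would raise a RuntimeError containing
# "does not require grad", which is exactly what the test above asserts.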