Commit 9ab603a

feat(ce,flce): decouple gradients computation for no_grad mode
1 parent: cf43acd
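
This change threads a new HAS_GRADIENTS flag, derived from the logits' requires_grad, into the cross entropy and fused linear cross entropy kernels, so forward-only calls compute the loss without materializing gradients in place. A minimal usage sketch, modeled on the new test added in this commit; the liger_kernel.transformers import path and the CUDA device are assumptions for illustration:

```python
import torch

# Assumed public import path (the test file imports LigerCrossEntropyLoss similarly).
from liger_kernel.transformers import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(reduction="mean")

# Logits that do not require grad (e.g. evaluation) take the HAS_GRADIENTS=False
# path, so the kernel skips the in-place dX pass.
logits = torch.randn(4 * 512, 32000, device="cuda")
target = torch.randint(0, 32000, (4 * 512,), device="cuda")

with torch.no_grad():
    loss = loss_fn(logits, target)

# loss.requires_grad is False here; calling loss.backward() raises a RuntimeError
# ("... does not require grad"), which is exactly what the new test asserts.
```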

6 files changed (+238 -60 lines)

benchmark/scripts/benchmark_cross_entropy.py

Lines changed: 11 additions & 3 deletions
@@ -23,6 +23,7 @@ def bench_memory_cross_entropy(

     V = input.x
     provider = input.kernel_provider
+    mode = input.kernel_operation_mode
     B = input.extra_benchmark_config["B"]
     T = input.extra_benchmark_config["T"]

@@ -39,7 +40,11 @@ def full():
         y = fwd()
         y.backward()

-    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    if mode == "full":
+        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    elif mode == "no-grad-full":
+        with torch.no_grad():
+            mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
     return SingleBenchmarkRunOutput(
         y_20=mem_20,
         y_50=mem_50,
@@ -70,6 +75,9 @@ def fwd():

     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "no-grad-forward":
+        with torch.no_grad():
+            ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
         y = fwd()

@@ -109,14 +117,14 @@ def full():

     run_benchmarks(
         bench_test_fn=bench_speed_cross_entropy,
-        kernel_operation_modes=["forward", "backward", "full"],
+        kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
         metric_name="speed",
         metric_unit="ms",
         **common_configs,
     )
     run_benchmarks(
         bench_test_fn=bench_memory_cross_entropy,
-        kernel_operation_modes=["full"],
+        kernel_operation_modes=["full", "no-grad-full"],
         metric_name="memory",
         metric_unit="MB",
         **common_configs,

benchmark/scripts/benchmark_fused_linear_cross_entropy.py

Lines changed: 16 additions & 3 deletions
@@ -58,6 +58,7 @@ def bench_memory_fused_linear_cross_entropy(
     V = input.extra_benchmark_config["V"]
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
+    mode = input.kernel_operation_mode

     torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
     liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
@@ -78,7 +79,12 @@ def full():
         y = fwd()
         y.backward()

-    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    if mode == "full":
+        mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    elif mode == "no-grad-full":
+        with torch.no_grad():
+            mem_50, mem_20, mem_80 = _test_memory(fwd, _iter=10, quantiles=QUANTILES)
+
     return SingleBenchmarkRunOutput(
         y_20=mem_20,
         y_50=mem_50,
@@ -122,6 +128,13 @@ def fwd():
             rep=100,
             quantiles=QUANTILES,
         )
+    elif mode == "no-grad-forward":
+        with torch.no_grad():
+            ms_50, ms_20, ms_80 = triton.testing.do_bench(
+                fwd,
+                rep=100,
+                quantiles=QUANTILES,
+            )
     elif mode == "backward":
         y = fwd()

@@ -164,14 +177,14 @@ def full():

     run_benchmarks(
         bench_test_fn=bench_speed_fused_linear_cross_entropy,
-        kernel_operation_modes=["forward", "backward", "full"],
+        kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
         metric_name="speed",
         metric_unit="ms",
         **common_configs,
     )
     run_benchmarks(
         bench_test_fn=bench_memory_fused_linear_cross_entropy,
-        kernel_operation_modes=["full"],
+        kernel_operation_modes=["full", "no-grad-full"],
         metric_name="memory",
         metric_unit="MB",
         **common_configs,

src/liger_kernel/ops/cross_entropy.py

Lines changed: 55 additions & 52 deletions
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
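
Because HAS_GRADIENTS is a tl.constexpr, Triton specializes the kernel per value, so the False variant drops the gradient loop at compile time rather than branching at runtime; the wrapper simply forwards _input.requires_grad. A small illustrative sketch (tensor names are hypothetical) of which inputs hit which specialization:

```python
import torch

# Illustration only: the launch above passes HAS_GRADIENTS=_input.requires_grad.
x_train = torch.randn(4, 32000, device="cuda", requires_grad=True)  # gradient pass kept
x_eval = torch.randn(4, 32000, device="cuda")  # mirrors the new test's detached inputs

print(x_train.requires_grad)  # True  -> HAS_GRADIENTS=True specialization
print(x_eval.requires_grad)   # False -> gradient loop compiled out; logits are not overwritten
```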

src/liger_kernel/ops/fused_linear_cross_entropy.py

Lines changed: 3 additions & 2 deletions
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
            RETURN_Z_LOSS=return_z_loss,
            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
+           HAS_GRADIENTS=logits_chunk.requires_grad,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(

        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

-       if grad_weight is not None:
+       if grad_weight is not None and _input_chunk.requires_grad:
            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-       if bias is not None:
+       if bias is not None and _input_chunk.requires_grad:
            torch.add(
                input=grad_bias,
                other=grad_logits_chunk.sum(dim=0),

test/transformers/test_cross_entropy.py

Lines changed: 41 additions & 0 deletions
@@ -454,6 +454,27 @@ def _test_correctness_not_last_layer_with_other_params_once(
     loss2.backward(gradient=grad_output)
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)

+def _test_correctness_with_forward_only(target_ce, B, T, V, reduction, dtype, scalar, atol, rtol):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(reduction=reduction)
+
+    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
+    _input = _tensor.detach().clone()
+    _input2 = _tensor.detach().clone()
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    with torch.no_grad():
+        output = torch_ce(_input, target)
+        output2 = target_ce(_input2, target)
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    try:
+        # Try running backward on liger output
+        output2.backward(gradient=torch.ones_like(output))
+    except RuntimeError as e:
+        assert "does not require grad" in str(e)
+

 def _test_correctness_functional(
     B,
@@ -1061,3 +1082,23 @@ def test_float32_internal():
 def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
     liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
     _test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)
+
+@pytest.mark.parametrize(
+    "B, T, V, ignore_index",
+    [
+        (2, 4096, 32000, -100),
+        (3, 423, 32000, 2),
+    ],
+)
+@pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
+@pytest.mark.parametrize(
+    "dtype, scalar, atol, rtol",
+    [
+        (torch.float32, 1.0, 1e-4, 1e-4),
+        (torch.float16, 1.0, 1e-2, 1e-2),
+        (torch.bfloat16, 1.0, 1e-2, 1e-2),
+    ],
+)
+def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol):
+    liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
+    _test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol)
