diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py
index e36fd1db2..cdd61814a 100644
--- a/benchmark/scripts/benchmark_cross_entropy.py
+++ b/benchmark/scripts/benchmark_cross_entropy.py
@@ -70,6 +70,9 @@ def fwd():
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "no-grad-forward":
+        with torch.no_grad():
+            ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
         y = fwd()
 
@@ -109,7 +112,7 @@ def full():
 
     run_benchmarks(
        bench_test_fn=bench_speed_cross_entropy,
-        kernel_operation_modes=["forward", "backward", "full"],
+        kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
         metric_name="speed",
         metric_unit="ms",
         **common_configs,
diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
index d0af655b3..4d36a66a6 100644
--- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
+++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
@@ -59,26 +59,26 @@ def bench_memory_fused_linear_cross_entropy(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
+    lm_head_ce = None
+    if provider == "liger":
+        lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
+    elif provider == "liger-fp32-accum":
+        lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
+    else:
+        lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
 
     _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
 
     def fwd():
-        if provider == "liger":
-            return liger_lm_head_ce(_input, target)
-        elif provider == "liger-fp32-accum":
-            return liger_lm_head_ce_fp32_accum(_input, target)
-        elif provider == "huggingface":
-            return torch_lm_head_ce(_input, target)
+        return lm_head_ce(_input, target)
 
     def full():
         y = fwd()
         y.backward()
 
     mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+
     return SingleBenchmarkRunOutput(
         y_20=mem_20,
         y_50=mem_50,
@@ -101,20 +101,19 @@ def bench_speed_fused_linear_cross_entropy(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
+    lm_head_ce = None
+    if provider == "liger":
+        lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
+    elif provider == "liger-fp32-accum":
+        lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
+    else:
+        lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
 
     _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
 
     def fwd():
-        if provider == "liger":
-            return liger_lm_head_ce(_input, target)
-        elif provider == "liger-fp32-accum":
-            return liger_lm_head_ce_fp32_accum(_input, target)
-        elif provider == "huggingface":
-            return torch_lm_head_ce(_input, target)
+        return lm_head_ce(_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -122,6 +121,13 @@ def fwd():
             rep=100,
             quantiles=QUANTILES,
         )
+    elif mode == "no-grad-forward":
+        with torch.no_grad():
+            ms_50, ms_20, ms_80 = triton.testing.do_bench(
+                fwd,
+                rep=100,
+                quantiles=QUANTILES,
+            )
     elif mode == "backward":
         y = fwd()
 
@@ -164,7 +170,7 @@ def full():
 
     run_benchmarks(
         bench_test_fn=bench_speed_fused_linear_cross_entropy,
-        kernel_operation_modes=["forward", "backward", "full"],
+        kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
         metric_name="speed",
         metric_unit="ms",
         **common_configs,
diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 9c886b587..526a4f2ff 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether to calculate gradients in the forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 1cf59af7b..d2fecf047 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
             RETURN_Z_LOSS=return_z_loss,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=_input.requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(
 
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if grad_weight is not None:
+        if grad_weight is not None and _input.requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
 
-        if bias is not None:
+        if bias is not None and _input.requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index d1e5ee76c..5a98bb1f1 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -455,6 +455,28 @@ def _test_correctness_not_last_layer_with_other_params_once(
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
 
+def _test_correctness_with_forward_only(target_ce, B, T, V, reduction, dtype, scalar, atol, rtol):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(reduction=reduction)
+
+    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
+    _input = _tensor.detach().clone()
+    _input2 = _tensor.detach().clone()
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    with torch.no_grad():
+        output = torch_ce(_input, target)
+        output2 = target_ce(_input2, target)
+        assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    try:
+        # Try running backward on liger output
+        output2.backward(gradient=torch.ones_like(output))
+    except RuntimeError as e:
+        assert "does not require grad" in str(e)
+
+
 def _test_correctness_functional(
     B,
     T,
@@ -1014,6 +1036,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0,  # False
         HAS_WEIGHT=False,
         HAS_SOFTCAPPING=False,
+        HAS_GRADIENTS=True,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32 if not is_hip() else 16,
     )
@@ -1042,6 +1065,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0,  # False
         HAS_WEIGHT=False,
         HAS_SOFTCAPPING=False,
+        HAS_GRADIENTS=True,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32 if not is_hip() else 16,
     )
@@ -1061,3 +1085,24 @@
 def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
     liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
     _test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)
+
+
+@pytest.mark.parametrize(
+    "B, T, V, ignore_index",
+    [
+        (2, 4096, 32000, -100),
+        (3, 423, 32000, 2),
+    ],
+)
+@pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
+@pytest.mark.parametrize(
+    "dtype, scalar, atol, rtol",
+    [
+        (torch.float32, 1.0, 1e-4, 1e-4),
+        (torch.float16, 1.0, 1e-2, 1e-2),
+        (torch.bfloat16, 1.0, 1e-2, 1e-2),
+    ],
+)
+def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol):
+    liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
+    _test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol)
diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py
index 00811fc31..3d409e745 100644
--- a/test/transformers/test_fused_linear_cross_entropy.py
+++ b/test/transformers/test_fused_linear_cross_entropy.py
@@ -231,6 +231,119 @@ def test_correctness(
     )
 
 
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (8, 128, 1024, 4096),
+        (4, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "reduction, scalar, dtype, atol, rtol",
+    [
+        ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2),
+        ("mean", 1.0, torch.float32, 1e-5, 5e-4),
+        ("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
+        ("sum", 1.0, torch.float32, 1e-3, 5e-2),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize(
+    "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss, accum_dtype",
+    [
+        (False, 0, -100, 0, None, False, None),
+        # Pass non-default values once to ensure all params work together
+        (True, 0.1, 42, 1e-4, 30.0, True, torch.float32),
+    ],
+)
+def test_correctness_with_forward_only(
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    bias,
+    has_ce_weight,
+    lse_square_scale,
+    label_smoothing,
+    ignore_index,
+    reduction,
+    softcap,
+    return_z_loss,
+    accum_dtype,
+    atol,
+    rtol,
+):
+    if has_ce_weight:
+        ce_weight = torch.rand(V, device=device, dtype=torch.float32)
+    else:
+        ce_weight = None
+    torch_lm_head_ce = TorchLMHeadCE(
+        H=H,
+        V=V,
+        bias=bias,
+        ce_weight=ce_weight,
+        lse_square_scale=lse_square_scale,
+        label_smoothing=label_smoothing,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        softcap=softcap,
+        return_z_loss=return_z_loss,
+        dtype=dtype,
+    ).to(device)
+    liger_lm_head_ce = LigerLMHeadCE(
+        H=H,
+        V=V,
+        bias=bias,
+        ce_weight=ce_weight,
+        lse_square_scale=lse_square_scale,
+        label_smoothing=label_smoothing,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        softcap=softcap,
+        return_z_loss=return_z_loss,
+        dtype=dtype,
+        accum_dtype=accum_dtype,
+    ).to(device)
+
+    # init the linear in all CEs with the same weights
+    torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype)
+
+    if bias:
+        torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(V, device=device, dtype=dtype)
+
+    _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
+    _input1 = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(
+        1, B * T // 2, (1,)
+    ).item()  # Random number of elements to set to ignore_index
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]  # Randomly select indices
+    target[indices_to_assign] = ignore_index
+
+    with torch.no_grad():
+        if return_z_loss:
+            output1, z_output1 = torch_lm_head_ce(_input1, target)
+            output2, z_output2 = liger_lm_head_ce(_input2, target)
+        else:
+            output1 = torch_lm_head_ce(_input1, target)
+            output2 = liger_lm_head_ce(_input2, target)
+
+    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
+    if return_z_loss:
+        assert_verbose_allclose(z_output1, z_output2, atol=atol, rtol=rtol)
+
+    try:
+        grad_output = torch.rand_like(output1)
+        output2.backward(gradient=grad_output)
+    except RuntimeError as e:
+        assert "does not require grad" in str(e)
+
+
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
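Usage reference for the forward-only path exercised above. This is a minimal sketch rather than part of the patch: it assumes the public LigerCrossEntropyLoss module used in the tests (imported here from liger_kernel.transformers) and illustrative tensor shapes. Because the logits are created without requires_grad and the loss is evaluated under torch.no_grad(), cross_entropy_forward launches the kernel with HAS_GRADIENTS=False, so the in-place gradient write-back into the logits is skipped.

import torch

from liger_kernel.transformers import LigerCrossEntropyLoss  # assumed import path for the module used in the tests

# Illustrative shapes; any (B * T, V) logits with a (B * T,) target works.
B, T, V = 2, 1024, 32000
device = "cuda"

ce = LigerCrossEntropyLoss(reduction="mean")
logits = torch.randn(B * T, V, device=device, dtype=torch.bfloat16)  # requires_grad is False
target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

with torch.no_grad():
    # Forward only: HAS_GRADIENTS is False, so the kernel skips the gradient
    # loop and does not overwrite the logits buffer with dX.
    loss = ce(logits, target)

print(loss.item())

Calling loss.backward() here would raise the "does not require grad" RuntimeError that the new tests assert on.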