Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/liger_kernel/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ def forward(
Returns:
tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
"""
input_requires_grad = _input.requires_grad

loss, z_loss, _input = cross_entropy_forward(
_input,
target,
Expand All @@ -428,7 +430,8 @@ def forward(
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
ctx.save_for_backward(_input.detach())
if input_requires_grad:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss

return loss, z_loss
Expand Down
24 changes: 14 additions & 10 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def fused_linear_cross_entropy_forward(
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
device = _input.device

input_requires_grad = _input.requires_grad

# inputs have shape: BT x H
# materialized activations will have shape: BT x V
# the increase in memory = BT x V
Expand All @@ -49,12 +51,13 @@ def fused_linear_cross_entropy_forward(
grad_input = torch.zeros_like(_input, device=device)

# we use fp32 for loss and gradients accumulator
if accum_dtype is None:
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
else:
grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
if input_requires_grad:
if accum_dtype is None:
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
else:
grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None

loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
Expand Down Expand Up @@ -150,7 +153,7 @@ def fused_linear_cross_entropy_forward(
RETURN_Z_LOSS=return_z_loss,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=_input.requires_grad,
HAS_GRADIENTS=input_requires_grad,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
Expand All @@ -172,12 +175,13 @@ def fused_linear_cross_entropy_forward(
scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
if input_requires_grad:
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

if grad_weight is not None and _input.requires_grad:
if grad_weight is not None and input_requires_grad:
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

if bias is not None and _input.requires_grad:
if bias is not None and input_requires_grad:
torch.add(
input=grad_bias,
other=grad_logits_chunk.sum(dim=0),
Expand Down
Loading