diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 526a4f2ff..6ad75f93b 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -414,6 +414,8 @@ def forward(
         Returns:
             tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
+        input_requires_grad = _input.requires_grad
+
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
@@ -428,7 +430,8 @@ def forward(
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
 
         return loss, z_loss
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index d2fecf047..9a33a42a9 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -31,6 +31,8 @@ def fused_linear_cross_entropy_forward(
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device
 
+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -49,12 +51,13 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)
 
     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
 
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
@@ -150,7 +153,7 @@ def fused_linear_cross_entropy_forward(
             RETURN_Z_LOSS=return_z_loss,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
-            HAS_GRADIENTS=_input.requires_grad,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -172,12 +175,13 @@ def fused_linear_cross_entropy_forward(
             scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
             grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if grad_weight is not None and _input.requires_grad:
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
 
-        if bias is not None and _input.requires_grad:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
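
For reference, below is a minimal standalone sketch of the pattern this diff applies: read `_input.requires_grad` once up front, then gate gradient-buffer allocation and gradient accumulation on that flag so inference-style calls (input does not need gradients) skip the extra memory and compute. This is an illustration only, not Liger's actual kernel path; `toy_linear_ce_forward` and its signature are hypothetical.

import torch


def toy_linear_ce_forward(_input: torch.Tensor, weight: torch.Tensor, target: torch.Tensor):
    # Read requires_grad once and reuse it, mirroring the patch's
    # `input_requires_grad = _input.requires_grad` hoisting.
    input_requires_grad = _input.requires_grad

    # Gradient accumulators are only materialized when the input can ever
    # receive gradients; under inference they are skipped entirely.
    grad_input = torch.zeros_like(_input) if input_requires_grad else None
    grad_weight = torch.zeros_like(weight) if (input_requires_grad and weight.requires_grad) else None

    logits = _input @ weight.t()
    loss = torch.nn.functional.cross_entropy(logits, target)

    if input_requires_grad:
        # Hand-written gradient math would accumulate into grad_input /
        # grad_weight here; it is skipped when no gradients are needed.
        pass

    return loss, grad_input, grad_weight


# Usage: with requires_grad=False on the input (e.g. evaluation), no gradient
# buffers are allocated and the gradient math is skipped.
x = torch.randn(4, 8, requires_grad=False)
w = torch.randn(16, 8, requires_grad=True)
y = torch.randint(0, 16, (4,))
loss, grad_x, grad_w = toy_linear_ce_forward(x, w, y)
assert grad_x is None and grad_w is None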