Closed
Description
🐛 Describe the bug
After #906, trying to run eval in no_grad mode fails with the following error:
```
  File "/home/ket/ws/camfer/BlobLearn/.venv/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 181, in fused_linear_cross_entropy_forward
    if grad_weight is not None and input_requires_grad:
       ^^^^^^^^^^^
UnboundLocalError: cannot access local variable 'grad_weight' where it is not associated with a value
```

Looking at the commit, if input_requires_grad is False, grad_weight is never assigned at all! I think it just needs to be initialized to None.
I can put up a PR ASAP, but wanted to check my understanding first, since I'm not sure how the unit tests are currently passing. Thanks!
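A minimal, library-free sketch of the failure pattern described above, assuming the relevant part of `fused_linear_cross_entropy_forward` reduces to a conditional assignment of `grad_weight` followed by the check at line 181. The function names and the float stand-ins for the gradient tensor are hypothetical; this is not the actual Liger Kernel source.

```python
from typing import Optional


def forward_before_fix(input_requires_grad: bool) -> Optional[float]:
    """Hypothetical stand-in: grad_weight is bound only in one branch."""
    if input_requires_grad:
        grad_weight = 0.0  # stand-in for the real gradient buffer
    # When input_requires_grad is False, this raises UnboundLocalError:
    # `grad_weight is not None` is evaluated before `and` can short-circuit.
    if grad_weight is not None and input_requires_grad:
        grad_weight += 1.0  # stand-in for the real gradient accumulation
    return grad_weight


def forward_with_fix(input_requires_grad: bool) -> Optional[float]:
    """Hypothetical stand-in for the proposed fix: bind grad_weight up front."""
    grad_weight = None  # bound on every path, so the check below is safe
    if input_requires_grad:
        grad_weight = 0.0
    if grad_weight is not None and input_requires_grad:
        grad_weight += 1.0
    return grad_weight


if __name__ == "__main__":
    try:
        forward_before_fix(False)  # mirrors running under torch.no_grad()
    except UnboundLocalError as exc:
        print("before fix:", exc)
    print("with fix:", forward_with_fix(False))
```

Reordering the check to `input_requires_grad and grad_weight is not None` would also avoid the error through short-circuiting, but initializing `grad_weight = None` keeps the variable bound on every path, which is less fragile if the condition changes later.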
Reproduce
```python
import torch
import torch.nn as nn
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

vocab_size, hidden_dim, num_tokens = 1000, 512, 256
device = "cuda" if torch.cuda.is_available() else "cpu"

linear = nn.Linear(hidden_dim, vocab_size, bias=False).to(device)
fused_loss_fn = LigerFusedLinearCrossEntropyLoss()

hidden_states = torch.randn(num_tokens, hidden_dim, device=device)
labels = torch.randint(0, vocab_size, (num_tokens,), device=device)

# Under no_grad, input_requires_grad is False, which hits the error above
with torch.no_grad():
    loss = fused_loss_fn(linear.weight, hidden_states, labels)
print(f"Loss: {loss.item()}")
```

Versions
Environment Report:
Operating System: Linux-6.17.3-arch2-1-x86_64-with-glibc2.42
Python version: 3.12.12
Liger Kernel version: 0.6.3
PyTorch version: 2.9.0+cu128
CUDA version: 12.8
HIP(ROCm) version: Not available
Triton version: 3.5.0
Transformers version: 4.57.1
XPU version: XPU Not Available