Skip to content

Commit a0b849e

Browse files
committed
feat(FLCE): expose accum_dtype for hf model monkey patch
Signed-off-by: Tcc0403 <[email protected]>
1 parent e6de786 commit a0b849e

File tree

3 files changed: +4 -2 lines changed

3 files changed: +4 -2 lines changed

src/liger_kernel/transformers/model/loss_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ def fixed_fused_linear_cross_entropy(
1313
num_items_in_batch: Optional[int] = None,
1414
ignore_index: int = -100,
1515
final_logit_softcapping: Optional[float] = None,
16+
accum_dtype: Optional[torch.dtype] = None,
1617
**kwargs,
1718
):
1819
reduction = "sum" if num_items_in_batch is not None else "mean"
@@ -23,6 +24,7 @@ def fixed_fused_linear_cross_entropy(
2324
reduction=reduction,
2425
ignore_index=ignore_index,
2526
softcap=final_logit_softcapping,
27+
accum_dtype=accum_dtype,
2628
)
2729
if reduction == "sum":
2830
loss = loss / num_items_in_batch

test/convergence/bf16/test_mini_models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -926,7 +926,7 @@ def run_mini_model(
926926
for i in range(num_steps):
927927
batch = next(loader_iter).to(model.device)
928928
optimizer.zero_grad()
929-
output = model(**batch)
929+
output = model(**batch, accum_dtype=torch.float32)
930930
output.loss.backward()
931931
optimizer.step()
932932
print(f"Step {i}, Loss: {output.loss.item()}")

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -860,7 +860,7 @@ def run_mini_model_multimodal(
860860
for i in range(num_steps):
861861
batch = next(loader_iter).to(model.device)
862862
optimizer.zero_grad()
863-
output = model(**batch)
863+
output = model(**batch, accum_dtype=torch.float32)
864864
output.loss.backward()
865865
optimizer.step()
866866

0 commit comments

Comments
 (0)