3 files changed, +4 −2 lines changed
src/liger_kernel/transformers/model

@@ -13,6 +13,7 @@ def fixed_fused_linear_cross_entropy(
     num_items_in_batch: Optional[int] = None,
     ignore_index: int = -100,
     final_logit_softcapping: Optional[float] = None,
+    accum_dtype: Optional[torch.dtype] = None,
     **kwargs,
 ):
     reduction = "sum" if num_items_in_batch is not None else "mean"
@@ -23,6 +24,7 @@ def fixed_fused_linear_cross_entropy(
         reduction=reduction,
         ignore_index=ignore_index,
         softcap=final_logit_softcapping,
+        accum_dtype=accum_dtype,
     )
     if reduction == "sum":
         loss = loss / num_items_in_batch
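
For context, the new accum_dtype keyword selects the dtype in which the loss is accumulated before the final reduction. The snippet below is a standalone illustration in plain PyTorch, not Liger Kernel's fused Triton kernel: it only shows why accumulating per-token losses in bfloat16 drifts away from a float32 accumulation, which is the numerical issue that passing accum_dtype=torch.float32 is presumably meant to avoid.

import torch

# Standalone illustration, not Liger Kernel code: compares accumulating per-token
# cross-entropy losses in bfloat16 versus float32. The fp32 path is what
# accum_dtype=torch.float32 is understood to request from the patched helper.
torch.manual_seed(0)
logits = torch.randn(8192, 256)
labels = torch.randint(0, 256, (8192,))

per_token = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

sum_bf16 = per_token.to(torch.bfloat16).sum()  # low-precision accumulation
sum_fp32 = per_token.to(torch.float32).sum()   # float32 accumulation

print(f"bf16 accumulation: {sum_bf16.item():.4f}")
print(f"fp32 accumulation: {sum_fp32.item():.4f}")
print(f"absolute drift:    {abs(sum_bf16.float() - sum_fp32).item():.4f}")

Because the parameter defaults to None and is only forwarded, call sites that never pass it should keep their previous behaviour.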
@@ -926,7 +926,7 @@ def run_mini_model(
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
-        output = model(**batch)
+        output = model(**batch, accum_dtype=torch.float32)
         output.loss.backward()
         optimizer.step()
         print(f"Step {i}, Loss: {output.loss.item()}")
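
The change above passes accum_dtype=torch.float32 straight into model(**batch, ...), which relies on the model's forward accepting extra keyword arguments and threading them into the loss helper's **kwargs. A toy sketch of that plumbing (ToyModel and toy_loss are made-up names for illustration, not transformers or Liger Kernel code):

from typing import Optional

import torch

def toy_loss(per_token_loss: torch.Tensor, accum_dtype: Optional[torch.dtype] = None, **kwargs) -> torch.Tensor:
    # Mirrors the patched helper's shape: an explicit accum_dtype keyword plus a **kwargs sink.
    if accum_dtype is not None:
        per_token_loss = per_token_loss.to(accum_dtype)  # reduce in the requested precision
    return per_token_loss.mean()

class ToyModel(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, **loss_kwargs) -> torch.Tensor:
        # Any unrecognised keyword (accum_dtype here) falls through to the loss helper.
        per_token = torch.randn(input_ids.shape[0], dtype=torch.bfloat16).abs()
        return toy_loss(per_token, **loss_kwargs)

model = ToyModel()
loss = model(torch.zeros(8, dtype=torch.long), accum_dtype=torch.float32)
print(loss.dtype)  # torch.float32: the reduction ran in the requested accumulation dtype

If the forward in use does not accept and forward unknown keyword arguments to its loss function, the extra accum_dtype argument would instead raise a TypeError, so this pattern only works for models wired up the way the patched helper expects.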
@@ -860,7 +860,7 @@ def run_mini_model_multimodal(
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
-        output = model(**batch)
+        output = model(**batch, accum_dtype=torch.float32)
         output.loss.backward()
         optimizer.step()
