Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions fms_mo/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
logger.info("All blocks are computed for evaluation")

nlls = []
logits_scaling = getattr(model.config, "logits_scaling", 1)
# for i, data_mb in enumerate(dloader): #if using dloader.
for i in tqdm(range(qcfg["n_samples"]), desc="Final Evaluating..."):
hidden_states = qcfg["cached_input"][i].to(dev)
Expand All @@ -106,9 +107,7 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
hidden_states = ln_f(hidden_states)
lm_head.to(dev)
lm_logits = lm_head(hidden_states)

if model.config.model_type == "granite":
lm_logits /= model.config.logits_scaling
lm_logits /= logits_scaling

# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1, :].contiguous().float()
Expand Down
Loading