Add more specific comments
michaelbenayoun committed Oct 30, 2024
1 parent 48b0bb6 commit b0987a3
Showing 1 changed file with 7 additions and 1 deletion.
optimum/neuron/trainers.py (8 changes: 7 additions & 1 deletion)
@@ -2167,6 +2167,7 @@ def cross_entropy_loss(logits, labels):
        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        # Neuron-specific change compared to the original implementation in `trl`:
        # It is important to mark the step here to materialize the graph and tensors, otherwise the compiler fails in
        # `get_batch_loss_metrics` when adding `policy_rejected_logits` and `policy_chosen_logits` to the `metrics`.
        xm.mark_step()
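For context on why the `mark_step` matters, the snippet below is a minimal, hypothetical sketch (not taken from `trainers.py`; the tensor shapes and metric names are illustrative) of how `xm.mark_step()` from `torch_xla` cuts the lazily recorded XLA graph and executes it, so that tensors such as the chosen/rejected logits are materialized before they are read back as metrics:

```python
import torch
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed (Neuron uses the XLA backend)

device = xm.xla_device()

# On an XLA device these operations are only recorded into a lazy graph;
# nothing has been compiled or executed yet.
all_logits = torch.randn(8, 32, device=device)
chosen_logits = all_logits[:4]
rejected_logits = all_logits[4:]

# Cut the lazy graph and execute it: after this call the tensors above are
# backed by real device buffers instead of pending graph nodes.
xm.mark_step()

# Reading values for logging/metrics now works on materialized tensors.
metrics = {
    "chosen_mean": chosen_logits.mean().item(),
    "rejected_mean": rejected_logits.mean().item(),
}
print(metrics)
```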
@@ -2181,12 +2182,17 @@ def odds_ratio_loss(
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:

        # Neuron-specific change compared to the original implementation in `trl`; the original implementation is:
        #
        # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
        # log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        #     torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
        # )
        #
        # But here we use `x |-> log(1 + x)` instead of `torch.log1p`, because `torch.log1p` produces NaNs in BF16
        # on Neuron devices.

        # We used this instead of `torch.log1p` because it produces NaNs in BF16.
        def log1p(x):
            return torch.log(1 + x)
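To make the workaround concrete, here is a minimal, self-contained sketch (simplified, not the actual `odds_ratio_loss` from `trainers.py`; the helper name `compute_log_odds` and the sample values are invented for illustration) showing how the plain `log(1 + x)` helper slots into the ORPO log-odds term in place of `torch.log1p`:

```python
import torch


def log1p(x):
    # Plain log(1 + x), used instead of torch.log1p, which produces NaNs in BF16 on Neuron devices.
    return torch.log(1 + x)


def compute_log_odds(policy_chosen_logps, policy_rejected_logps):
    # Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691, using
    # log(1 - P(y|x)) = log1p(-exp(log P(y|x))).
    return (policy_chosen_logps - policy_rejected_logps) - (
        log1p(-torch.exp(policy_chosen_logps)) - log1p(-torch.exp(policy_rejected_logps))
    )


# Illustrative BF16 log-probabilities (closer to 0 = more likely).
chosen = torch.tensor([-0.1, -0.5], dtype=torch.bfloat16)
rejected = torch.tensor([-1.2, -2.0], dtype=torch.bfloat16)
print(compute_log_odds(chosen, rejected))
```

Numerically both forms compute the same quantity; the rewrite exists only because `torch.log1p` itself produces NaNs in BF16 on Neuron devices, as the comment in the diff explains.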

