Logit processing in mlx #1649

Answered by awni
chimezie asked this question in Q&A
Dec 5, 2024 · 1 comments · 3 replies

PyTorch's gather method makes this very trivial, but there is no equivalent in mlx for design reasons.

We have fancy indexing in MLX which should be as expressive as gather. But maybe I'm misunderstanding what you are trying to do.
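
As a rough illustration of that equivalence, here is a minimal sketch that picks one score per position from a (batch, seq_len, vocab) array in two ways: with `mx.take_along_axis` and with broadcast fancy indexing. The toy shapes and the names `scores` and `targets` are assumptions for illustration and are not taken from the original question. The NumPy fallback is also an assumption, added only so the snippet runs without MLX installed; the calls used here have the same names in both libraries.

```python
try:
    import mlx.core as mx  # MLX, if it is installed
except ImportError:
    import numpy as mx  # assumption: NumPy stand-in, same call names used below

# Toy scores with shape (batch=2, seq_len=3, vocab=4); values are illustrative.
scores = mx.arange(24).reshape(2, 3, 4)
# One target token id per position, shape (2, 3).
targets = mx.array([[0, 1, 2], [3, 2, 1]])

# Gather-style: index along the vocab axis, then drop the singleton axis.
gathered = mx.take_along_axis(scores, targets[..., None], axis=-1).squeeze(-1)

# Fancy indexing: broadcast batch and position indices against the targets.
b = mx.arange(2)[:, None]  # shape (2, 1)
t = mx.arange(3)[None, :]  # shape (1, 3)
indexed = scores[b, t, targets]

print(gathered.tolist())  # [[0, 5, 10], [15, 18, 21]]
print(indexed.tolist())   # [[0, 5, 10], [15, 18, 21]]
```

Both forms return an array of shape (2, 3), so either should work as a stand-in for a gather over the last axis in a case like this.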

Your second function (get_3d_target_scores) looks ok to me, though I don't quite get what it's doing at the end where you sum over the sequence length. You can simplify it a bit, which will also make it a bit faster:

def get_3d_target_scores(batch_scores, target_tokens):
    full_seq_len = batch_scores.shape[1]
    target_seq_size = max(len(i) for i in target_tokens)
    arr = mx.take_along_axis(
        batch_scores,
        mx.array([[0] * (batch_scores.shape[1] - l…


3 replies
@chimezie

@awni

awni Dec 7, 2024
Maintainer

@chimezie

Answer selected by chimezie