Vectorized insert and searchsort without vmap #28131

Answered by jakevdp
Methylamphetamine asked this question in Q&A

Hi, it's a good question! Actually, I suspect the best way to write this would be something like:

import jax.numpy as jnp

all_scores = jnp.concatenate([scores, new_scores], axis=1)  # (batch, n) and (batch, k) -> (batch, n + k)
all_values = jnp.concatenate([values, new_values], axis=1)

indices = jnp.argsort(all_scores, axis=1)  # per-row sort order
batch_indices = jnp.arange(indices.shape[0])[:, None]  # (batch, 1) row ids, broadcast against indices

updated_scores = all_scores[batch_indices, indices]
updated_values = all_values[batch_indices, indices]
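To make the snippet concrete, here is a minimal sketch that wraps it in a jitted function and runs it on a batch of two rows. The helper name merge_sorted and the toy arrays are hypothetical; the sketch assumes scores/values have shape (batch, n) and new_scores/new_values have shape (batch, k).

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_sorted(scores, values, new_scores, new_values):
    # Hypothetical helper wrapping the answer's approach.
    # Static-shape concatenation: output is always (batch, n + k).
    all_scores = jnp.concatenate([scores, new_scores], axis=1)
    all_values = jnp.concatenate([values, new_values], axis=1)
    # Per-row sort order of the combined scores.
    indices = jnp.argsort(all_scores, axis=1)
    # (batch, 1) column of row ids that broadcasts against `indices`.
    batch_indices = jnp.arange(indices.shape[0])[:, None]
    # Gather scores and values with the same permutation so they stay aligned.
    return all_scores[batch_indices, indices], all_values[batch_indices, indices]


scores = jnp.array([[1.0, 4.0, 9.0], [2.0, 3.0, 7.0]])
values = jnp.array([[10, 40, 90], [20, 30, 70]])
new_scores = jnp.array([[5.0], [1.0]])
new_values = jnp.array([[50], [10]])

merged_scores, merged_values = merge_sorted(scores, values, new_scores, new_values)
print(merged_scores.tolist())  # [[1.0, 4.0, 5.0, 9.0], [1.0, 2.0, 3.0, 7.0]]
print(merged_values.tolist())  # [[10, 40, 50, 90], [10, 20, 30, 70]]
```

Because the output shape depends only on the input shapes, jax.jit traces this once per shape combination, and there is no data-dependent insertion index anywhere in the computation.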

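The reason for this is that XLA has no searchsorted primitive, so jnp.searchsorted has to be emulated with other operations and can be somewhat slow in practice. And by doing a static concatenation rather than a dynamic insertion, every output shape is known at trace time, so all the operations are JIT-compatible. The batch dimension is handled by broadcasting batch_indices (shape (batch, 1)) against indices, so no vmap is needed.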
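If you would rather skip the separate argsort-then-gather step, one option worth trying (not benchmarked here) is jax.lax.sort, which accepts a tuple of same-shaped operands and, with num_keys=1, sorts all of them by the first operand. A minimal sketch, with a hypothetical helper name merge_sorted_lax and assuming (batch, n) existing entries plus (batch, k) new entries:

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_sorted_lax(scores, values, new_scores, new_values):
    # Hypothetical helper. Static-shape concatenation, as in the answer.
    all_scores = jnp.concatenate([scores, new_scores], axis=1)
    all_values = jnp.concatenate([values, new_values], axis=1)
    # Sort along axis 1 using all_scores as the only key; all_values is
    # permuted identically, so scores and values stay aligned.
    return jax.lax.sort((all_scores, all_values), dimension=1, num_keys=1)


scores = jnp.array([[1.0, 4.0, 9.0], [2.0, 3.0, 7.0]])
values = jnp.array([[10, 40, 90], [20, 30, 70]])
new_scores = jnp.array([[5.0], [1.0]])
new_values = jnp.array([[50], [10]])

sorted_scores, sorted_values = merge_sorted_lax(scores, values, new_scores, new_values)
print(sorted_scores.tolist())  # [[1.0, 4.0, 5.0, 9.0], [1.0, 2.0, 3.0, 7.0]]
print(sorted_values.tolist())  # [[10, 40, 50, 90], [10, 20, 30, 70]]
```

Both variants produce the same ordering on this input; which one is faster will likely depend on the backend and array sizes, so timing both on your real shapes is the safest way to choose.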

Hope that helps!

Replies: 1 comment 3 replies

Answer selected by Methylamphetamine
Category: Q&A
2 participants