Vectorized insert and searchsort without vmap #28131
Hi all, greetings! I'm posting to ask whether there is any native function that supports "insert along axis" or "search sorted along axis" without vmapping the existing … The use case is that I have a sorted score array …

Currently …
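Hi, it's a good question! Actually, I suspect the best way to write this would be something like:

```python
import jax.numpy as jnp

# scores, values: shape (batch, n); new_scores, new_values: shape (batch, k).
# Append the new entries to each row (static shapes, so this is JIT-friendly).
all_scores = jnp.concatenate([scores, new_scores], axis=1)
all_values = jnp.concatenate([values, new_values], axis=1)

# Sort each row by score, then apply the same permutation to the values.
indices = jnp.argsort(all_scores, axis=1)
batch_indices = jnp.arange(indices.shape[0])[:, None]
updated_scores = all_scores[batch_indices, indices]
updated_values = all_values[batch_indices, indices]
```

The reason for this is that XLA has no searchsorted primitive, so `jnp.searchsorted` has to be built from other operations and can be somewhat slow in practice. And because static concatenation (rather than dynamic insertion) keeps every array shape fixed, all of these operations are JIT-compatible.

Hope that helps!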
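For comparison, here is a minimal sketch of two related vmap-free variants, assuming each row of `scores` is already sorted in ascending order. `merge_sorted` swaps the argsort-plus-gather step for `jax.lax.sort` with several operands, which sorts the values alongside the scores in one call (this is an alternative, not what the reply above proposes). `batched_searchsorted_left` is a hypothetical helper that computes row-wise `searchsorted(side="left")` by counting entries strictly less than each query. It builds a `(batch, n, k)` boolean array, so it is only a reasonable choice when `n * k` is small.

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_sorted(scores, values, new_scores, new_values):
    """Append new (score, value) pairs to each row and re-sort every row by score."""
    all_scores = jnp.concatenate([scores, new_scores], axis=1)
    all_values = jnp.concatenate([values, new_values], axis=1)
    # With a sequence of operands, lax.sort uses the first `num_keys` operands
    # as sort keys and permutes the remaining operands the same way.
    return jax.lax.sort((all_scores, all_values), dimension=1, num_keys=1)


@jax.jit
def batched_searchsorted_left(sorted_rows, queries):
    """Row-wise searchsorted(side="left") without vmap (hypothetical helper).

    For a row sorted ascending, the left insertion index of q equals the
    number of entries strictly less than q.
    """
    # Shape (batch, n, k): compare every row entry against every query of that row.
    less = sorted_rows[:, :, None] < queries[:, None, :]
    return jnp.sum(less, axis=1)


scores = jnp.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]])
values = jnp.array([[10, 30, 50], [20, 40, 60]])
new_scores = jnp.array([[4.0], [1.0]])
new_values = jnp.array([[40], [10]])

merged_scores, merged_values = merge_sorted(scores, values, new_scores, new_values)

queries = jnp.array([[4.0, 0.0], [6.0, 7.0]])
insert_idx = batched_searchsorted_left(scores, queries)
```

The broadcasted comparison avoids both vmap and the lowering of `jnp.searchsorted`, at the cost of extra memory; for large rows the concatenate-and-sort approach above is likely the better default.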