
Some efficient method for permuting elements of an NDArray to get a new NDArray #215

Open
hakanai opened this issue Jan 1, 2025 · 1 comment
Labels: api (Common api), enhancement (New feature or request)

Comments


hakanai commented Jan 1, 2025

One feature the Vector API offers, and which is also very common in shader code, is the ability to permute the elements of a vector to get a new vector. In computer graphics, we tend to call this "swizzling".

float4 a = float4(1, 2, 3, 4);
float4 b = a.wwxy;     // b now contains {4,4,1,2}
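Outside shader syntax, a swizzle is simply a gather by index: element i of the result is the source element at the i-th listed index, with repeats and any order allowed. A minimal plain-Kotlin sketch over a `DoubleArray`; the `swizzle` extension and the lane constants `X`, `Y`, `Z`, `W` are hypothetical names for illustration, not part of any library:

```kotlin
// Hypothetical helper: builds a new array whose i-th element is this[indices[i]].
// Indices may repeat and need not be in order, which is what swizzling requires.
fun DoubleArray.swizzle(vararg indices: Int): DoubleArray =
    DoubleArray(indices.size) { i -> this[indices[i]] }

// Hypothetical lane names for a 4-component vector.
const val X = 0
const val Y = 1
const val Z = 2
const val W = 3

fun main() {
    val a = doubleArrayOf(1.0, 2.0, 3.0, 4.0)
    val b = a.swizzle(W, W, X, Y) // same as a.wwxy in shader code
    println(b.joinToString()) // prints 4.0, 4.0, 1.0, 2.0
}
```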

One use case where I found this helpful was vector cross products. Shader languages provide the cross product as a built-in, but Kotlin does not, so I hacked something up with the Vector API that works roughly like this:

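// The Vector API is an incubator module: run with --add-modules jdk.incubator.vector
import jdk.incubator.vector.DoubleVector
import jdk.incubator.vector.VectorShuffle

// SPECIES_256 holds four doubles, used here as the lanes x, y, z, w.
private val yzxwShuffle = VectorShuffle.fromValues(DoubleVector.SPECIES_256, 1, 2, 0, 3)
private val zxywShuffle = VectorShuffle.fromValues(DoubleVector.SPECIES_256, 2, 0, 1, 3)

// Swizzle accessors: reorder the lanes according to the shuffle.
val DoubleVector.yzxw: DoubleVector
    get() = rearrange(yzxwShuffle)

val DoubleVector.zxyw: DoubleVector
    get() = rearrange(zxywShuffle)

// The Vector API defines no Kotlin operators, so map * and - onto lanewise mul/sub.
operator fun DoubleVector.times(other: DoubleVector): DoubleVector = mul(other)
operator fun DoubleVector.minus(other: DoubleVector): DoubleVector = sub(other)

// Vector cross product; the w lane of the result is 0 for finite inputs.
fun cross(cells: DoubleVector, theirCells: DoubleVector): DoubleVector =
    (cells.yzxw * theirCells.zxyw) - (cells.zxyw * theirCells.yzxw)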

I'm not sure whether there is a faster option for this, but I was pretty happy that I could use the same pair of swizzles on both sides of the subtraction, just swapping which cells they were applied to. I assume these vector shuffle operations are fairly cheap even on CPUs, because using them sped up my cross product quite a bit.
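Why one pair of swizzles serves both sides: each lane of `a.yzx * b.zxy` is one positive term of the cross product, and the same lane of `a.zxy * b.yzx` is the matching negative term. A scalar sketch on plain 3-element `DoubleArray`s (no Vector API, no w lane) that compares the swizzled form with the textbook component formula; `swizzle`, `crossSwizzled`, `crossDirect`, and the lanewise `*`/`-` operators are hypothetical helpers written for this illustration:

```kotlin
// Hypothetical helper: element i of the result is this[indices[i]].
fun DoubleArray.swizzle(indices: IntArray): DoubleArray =
    DoubleArray(indices.size) { i -> this[indices[i]] }

// Lanewise arithmetic on equal-length arrays (assumed, not from any library).
operator fun DoubleArray.times(o: DoubleArray) = DoubleArray(size) { this[it] * o[it] }
operator fun DoubleArray.minus(o: DoubleArray) = DoubleArray(size) { this[it] - o[it] }

// Cross product written with the same two swizzles, operands swapped on each side.
fun crossSwizzled(a: DoubleArray, b: DoubleArray): DoubleArray {
    val yzx = intArrayOf(1, 2, 0)
    val zxy = intArrayOf(2, 0, 1)
    return (a.swizzle(yzx) * b.swizzle(zxy)) - (a.swizzle(zxy) * b.swizzle(yzx))
}

// Textbook component-wise formula, for comparison.
fun crossDirect(a: DoubleArray, b: DoubleArray) = doubleArrayOf(
    a[1] * b[2] - a[2] * b[1],
    a[2] * b[0] - a[0] * b[2],
    a[0] * b[1] - a[1] * b[0],
)

fun main() {
    val a = doubleArrayOf(1.0, 0.0, 0.0)
    val b = doubleArrayOf(0.0, 1.0, 0.0)
    println(crossSwizzled(a, b).joinToString()) // prints 0.0, 0.0, 1.0
}
```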

So I would like access to something similar in NDArray, or at least in D1Array. I don't have a clear picture of what this would look like for higher-dimensional arrays. I thought I might be able to use slice() to get what I want, but it looks like each dimension passed to that method can select only a single index or a range of indices. I would like to pass a list of indices, which may be out of order.
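One possible reading of the request for higher dimensions, sketched in plain Kotlin under stated assumptions: a gather that takes a list of indices along one axis, in any order and with repeats allowed, similar in spirit to NumPy's `take(a, indices, axis)`. `Grid` is a hypothetical stand-in for a row-major `D2Array`, and `take` is not an existing Multik function; for a 1-D array the same idea reduces to picking elements by index.

```kotlin
// Hypothetical row-major 2-D array, standing in for a D2Array.
class Grid(val rows: Int, val cols: Int, val data: DoubleArray) {
    init { require(data.size == rows * cols) { "data size must equal rows * cols" } }

    operator fun get(r: Int, c: Int): Double {
        require(r in 0 until rows && c in 0 until cols) { "index ($r, $c) out of bounds" }
        return data[r * cols + c]
    }
}

// Hypothetical gather: keep the listed indices along `axis` (0 = rows, 1 = columns),
// in the given order; indices may repeat.
fun Grid.take(indices: IntArray, axis: Int): Grid = when (axis) {
    0 -> Grid(indices.size, cols, DoubleArray(indices.size * cols) { i ->
        this[indices[i / cols], i % cols]
    })
    1 -> Grid(rows, indices.size, DoubleArray(rows * indices.size) { i ->
        this[i / indices.size, indices[i % indices.size]]
    })
    else -> throw IllegalArgumentException("axis must be 0 or 1")
}

fun main() {
    // 2 x 3 grid: [[1, 2, 3], [4, 5, 6]]
    val g = Grid(2, 3, doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    val picked = g.take(intArrayOf(2, 0, 0), axis = 1)
    println(picked.data.joinToString()) // prints 3.0, 1.0, 1.0, 6.0, 4.0, 4.0
}
```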

devcrocod added the enhancement and api labels on Jan 8, 2025
devcrocod (Collaborator):

Thank you for the suggestion!
I would like to note that this method might not perform well because of the specifics of how the JVM operates. Nevertheless, I think such a method would be useful. I will look into what the API and its implementation could look like for ND arrays.
