Hello,
while looking through your RoPE CUDA implementation with the help of an LLM, it pointed out a potential race condition in the shared-memory access:
```cuda
for (int h = 0; h < H; h++)
{
    // load all of this head's token values into shared memory
    shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
    __syncthreads();
    const float u = shared[m];
    const float v = shared[m + Q];
    // write output
    if ((threadIdx.x % (D / 3)) < Q)
        tokens[b][n][h][threadIdx.x] = u * cos - v * sin;
    else
        tokens[b][n][h][threadIdx.x] = v * cos + u * sin;
    __syncthreads(); // <-- sync before overwriting shared in the next iteration
}
```
You already sync to ensure all threads have written their token value to `shared`; however, in theory one thread could advance to the next iteration and overwrite its slot in `shared` before another thread has read its `u`, `v` from it. I am not entirely sure whether this would actually materialize in practice, since the threads reading from `shared` should be in the same warp as the threads writing to `shared[m]`/`shared[m+Q]`. However, since Volta introduced independent thread scheduling, even lockstep execution within a warp is no longer guaranteed.
Or did I completely miss something here?
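To make the hazard concrete, here is a minimal sketch (kernel and helper names are hypothetical, not from your code) with the problematic interleaving spelled out in the comments:

```cuda
// Hypothetical illustration of the hazard, NOT the actual implementation.
// Possible timeline WITHOUT the trailing __syncthreads():
//
//   thread A                            thread B
//   --------                            --------
//   shared[a] = x_h         (iter h)    shared[b] = y_h   (iter h)
//   __syncthreads()                     __syncthreads()
//   u = shared[m]                       (descheduled)
//   ...writes output...
//   shared[a] = x_{h+1}   (iter h+1)    u = shared[m]     <-- may see h+1's value
//
__global__ void loop_over_heads_sketch(int H, int m)
{
    extern __shared__ float shared[];
    for (int h = 0; h < H; h++) {
        shared[threadIdx.x] = load_value(h, threadIdx.x); // hypothetical loader
        __syncthreads();              // barrier 1: everyone finished WRITING iter h
        const float u = shared[m];    // read a slot written by another thread
        store_value(h, threadIdx.x, u); // hypothetical store
        __syncthreads();              // barrier 2: everyone finished READING iter h
                                      // before anyone overwrites shared for h+1
    }
}
```

The second barrier protects the read-before-overwrite ordering across loop iterations, which the first barrier alone cannot do.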
Apart from that, it also suggested adding the `CHECK_KERNEL();` call after the launch, and a check to ensure `pos` is actually of type `int64`:

```cpp
TORCH_CHECK(pos.scalar_type() == torch::kInt64, "pos must be an int64 tensor");
```
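For context, a sketch of where those checks could sit in the host-side launcher (the function name, launch parameters, and the `CHECK_KERNEL` definition are assumptions based on common PyTorch-extension conventions, not your actual code):

```cpp
// Hypothetical host-side launcher; names and launch config are illustrative.
void rope_forward(torch::Tensor tokens, torch::Tensor pos)
{
    // Validate inputs up front so a wrong dtype fails loudly instead of the
    // kernel silently reinterpreting the position buffer.
    TORCH_CHECK(pos.scalar_type() == torch::kInt64, "pos must be an int64 tensor");
    TORCH_CHECK(tokens.is_cuda() && pos.is_cuda(), "tokens and pos must be CUDA tensors");

    rope_kernel<<<grid, block, smem_bytes>>>(/* ... */); // assumed launch

    // Surface launch errors immediately; a common definition would be:
    //   #define CHECK_KERNEL() do { cudaError_t e = cudaGetLastError(); \
    //       TORCH_CHECK(e == cudaSuccess, cudaGetErrorString(e)); } while (0)
    CHECK_KERNEL();
}
```

Checking right after the launch keeps the error attributed to this kernel rather than being reported by whatever CUDA call happens to run next.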