
Stack or shared memory version of dr.local? #397


Open

sleepyeye opened this issue May 21, 2025 · 1 comment

Comments

@sleepyeye

sleepyeye commented May 21, 2025

Hi. I have a question similar to #195 and #387.

Currently I'm trying to implement a kernel that processes two large arrays X and Y, where each array has shape N×M and each element is an Array4f (i.e., N×M×4 floats in total).
In the kernel, each thread processes one Array4f from X and generates m (1 <= m <= M) Array4f results, i.e. m×4 floats (each thread has a different m depending on the value of its Array4f).

To implement such a kernel, I used dr.local to allocate a scratch buffer of M Array4f entries and called scatter_add multiple times to submit the results to Y.
This approach works, but the kernel consumes quite a large amount of memory depending on the size of N. Furthermore, the performance is poor (I suspect poor memory bandwidth due to the large memory consumption, or poor cache hit rates, but I'm not sure yet).
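
Roughly, my current approach looks like the following sketch (simplified; the dr.alloc_local call, the flat i*M+j index layout, and the process function name are just for illustration, the real code is more involved):

import drjit as dr
from drjit.cuda import Float, UInt32, Array4f

def process(X: Float, Y: Float, i: UInt32, M: int):
    # load one Array4f per thread from the flat buffer X
    x = dr.gather(Array4f, X, i)

    # per-thread scratch buffer holding M Array4f entries (dr.Local)
    buf = dr.alloc_local(Array4f, M, value=dr.zeros(Array4f))

    # ... complex calculation that updates m (1 <= m <= M) entries of buf
    # depending on the value of x ...

    # submit the results to Y with atomic scatter-adds
    for j in range(M):
        dr.scatter_add(Y, buf[j], i * M + j)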

Since I know that a group of M threads accesses the same region of X and Y, I could obviously reduce the memory footprint of the kernel and increase performance if I could somehow place the scratch buffer allocated by dr.local_alloc in the stack of each thread or in shared memory.
Moreover, this way I could exploit block-wise or warp-level reductions to achieve higher performance.
What I have described can be written simply in CUDA as follows:

// M is a reasonably small number (< 256)
struct vec4f { float data[4]; };
constexpr int M = 32;

__device__
void f(vec4f *buffer_y, bool *y_updated, const vec4f &x);

__device__
void reduction(vec4f *buffer_y, bool *y_updated);

__global__
void foo(int N, vec4f *X, vec4f *Y)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  int j = ...;  // index in block or warp
  
  vec4f x = X[i];
  vec4f buffer_y[M] = {};        // zero-initialized scratch buffer
  bool y_updated[M] = {false};
  
  // do some complex calculation:
  // f updates m (1 <= m <= M) elements of buffer_y depending on the value of x
  f(buffer_y, y_updated, x);

  // block-wise or warp-level reduction that syncs `buffer_y` and `y_updated`
  // (basically it just performs a block-wise or warp-level sum and
  // syncs all buffer_y in the same block)
  reduction(buffer_y, y_updated);

  // submit the result
  if(y_updated[j])
    Y[i] += buffer_y[j];
}

Unfortunately, I couldn't find any way to do this kind of low-level control in Dr.Jit from the Python API. Or am I missing something?
Is there any way to declare a small array of Array4f in the stack or shared memory of the generated kernel?

Lastly, if I have to write my own kernel in CUDA, is there an easy way to integrate it into an existing Dr.Jit application (Python side)?

@njroussel
Member

njroussel commented May 22, 2025

Hi @sleepyeye

I'm not 100% sure I fully understood your example, but I think I can still address your questions:

Is there any way to declare a small array of Array4f in the stack or shared memory of the generated kernel?

The dr.Local[T] type uses the "local" state space: it behaves like per-thread stack storage, but it is still physically backed by global (device) memory.
As for shared memory, Dr.Jit doesn't have any mechanism to explicitly handle/allocate shared memory at all.

Lastly, if I have to write my own kernel in CUDA, is there an easy way to integrate it into an existing Dr.Jit application (Python side)?

No, there isn't. However, you can use something like PyTorch's CUDA inlining, or write your own binding for a CUDA kernel with whatever tool you like. Dr.Jit will seamlessly interoperate with any device-resident data structure that implements the DLPack protocol (for example: torch, jax, tf).
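
For instance, something along these lines should work (untested sketch; the exact conversion calls, e.g. TensorXf(...) and torch.from_dlpack(...), may need adjusting for your Dr.Jit/PyTorch versions):

import drjit as dr
import torch
from drjit.cuda import Float, TensorXf

# Dr.Jit array -> PyTorch tensor via DLPack (zero-copy, stays on the GPU)
x = dr.arange(Float, 1024)
dr.eval(x)
x_torch = torch.from_dlpack(x)

# ... run your custom CUDA kernel here, e.g. one compiled with
# torch.utils.cpp_extension.load_inline(...); a simple stand-in:
y_torch = x_torch * 2

# PyTorch tensor -> Dr.Jit, again via DLPack
y = TensorXf(y_torch)

If you also need gradients to propagate across that boundary, the dr.wrap decorator might be what you want.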
