From 499834a4704c9d9a0b36a91d75e1c11d33e1395a Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Mon, 21 Jul 2025 17:19:27 +0100 Subject: [PATCH 01/10] feat: add CUDA scan implementation with thrust --- torchlpc/csrc/cuda/linear_recurrence.cu | 291 ------------------------ torchlpc/csrc/cuda/scan.cu | 116 ++++++++++ 2 files changed, 116 insertions(+), 291 deletions(-) delete mode 100644 torchlpc/csrc/cuda/linear_recurrence.cu create mode 100644 torchlpc/csrc/cuda/scan.cu diff --git a/torchlpc/csrc/cuda/linear_recurrence.cu b/torchlpc/csrc/cuda/linear_recurrence.cu deleted file mode 100644 index a7c7bb9..0000000 --- a/torchlpc/csrc/cuda/linear_recurrence.cu +++ /dev/null @@ -1,291 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#define CEIL_DIV(x, y) ((x + y - 1) / y) - -#define gpuErrChk(ans) \ - { \ - gpuAssert((ans), __FILE__, __LINE__); \ - } -void gpuAssert(cudaError_t code, const char *file, int line) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, - line); - } -} - -__device__ int2 divide_work(int n_jobs, int n_workers, int worker_idx) { - // Each worker will do a continuous slice of either n_jobs / n_workers - // or ceil_div(n_jobs, n_workers). The return value is an int2 representing - // a half open interval of jobs for the worker to perform (perform jobs - // i for a <= i < b) - - int cd = CEIL_DIV(n_jobs, n_workers); - int d = n_jobs / n_workers; - - int doing_cd = n_jobs % n_workers; - - int2 retval; - if (worker_idx < doing_cd) { - retval.x = worker_idx * cd; - retval.y = retval.x + cd; - } else { - retval.x = doing_cd * cd + (worker_idx - doing_cd) * d; - retval.y = retval.x + d; - } - - return retval; -} - -__device__ int2 compute_warp_start_stop(int block_idx, int warp_idx, - int n_blocks, int n_steps) { - int2 block_ss = divide_work(n_steps, n_blocks, block_idx); - int block_start = block_ss.x; - int block_stop = block_ss.y; - int block_jobs = block_stop - block_start; - - int2 warp_ss = divide_work(block_jobs, 32, warp_idx); - int warp_start = block_start + warp_ss.x; - int warp_stop = block_start + warp_ss.y; - - int2 retval; - retval.x = warp_start; - retval.y = warp_stop; - return retval; -} - -// decay storage, h_storage: -// each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block -// reduction -template -__global__ void reduction_kernel(const scalar_t *decays, - const scalar_t *impulses, - const scalar_t *initial_state, - scalar_t *_decay_storage, scalar_t *_h_storage, - int n_dims, int n_steps) { - int warp = threadIdx.x / 32; - int lane = threadIdx.x % 32; - - scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims]; - scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims]; - - int2 start_stop = - compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); - int warp_start = start_stop.x; - int warp_stop = start_stop.y; - - /* - * Reduce within warps. - * After this loop exits, the storage arrays should contain the reduction - * from warp_start to warp_stop (including initial state) at index - * (feature_idx, warp, block). - */ - for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { - scalar_t cum_decay = static_cast(1.0); - scalar_t h = static_cast(0.0); - if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) { - h = initial_state[i]; - } - - for (int t = warp_start; t < warp_stop; t++) { - cum_decay *= decays[i * n_steps + t]; - h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; - } - - // TODO: store into shared memory, work in shared memory sized blocks - // store into global memory - decay_storage[i + lane * n_dims] = cum_decay; - h_storage[i + lane * n_dims] = h; - } - - __syncthreads(); - - /* - * Reduce over warps. - * After this loop exits, the storage arrays should contain the reduction - * from block_start to block_finish (including initial state) at index - * (feature_idx, 32, block). - */ - // TODO: parallel reduction (or scan). Need to worry about changing the warp - // reduction values (as I use them again later) - for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { - scalar_t cum_decay = static_cast(1.0); - scalar_t h = static_cast(0.0); - for (int t = 0; t < 32; t++) { - cum_decay *= decay_storage[i + t * n_dims]; - h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims]; - } - decay_storage[i + 32 * n_dims] = cum_decay; - h_storage[i + 32 * n_dims] = h; - } -} - -template -__global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage, - int n_dims, int n_blocks) { - /* - * Scan over blocks. - * After this loop exits, the storage arrays should contain the cumulative - * sum from block_idx 0 to i (inclusive) at index (feature_idx, 32, i) This - * means (feature_idx, 32, 2) contains the reduction of blocks 0, 1, and 2. - */ - // TODO: parallel scan (tricky because number of blocks isn't necessarily - // smaller than number of warps that can fit in a single block) - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n_dims; - i += blockDim.x * gridDim.x) { - for (int t = 1; t < n_blocks; t++) { - int cur_idx = i + 32 * n_dims + t * 33 * n_dims; - int prev_idx = i + 32 * n_dims + (t - 1) * 33 * n_dims; - - // TODO: remove unneccessary reads from global memory (prev_idx - // accesses) - h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + - h_storage[cur_idx]; - decay_storage[cur_idx] *= decay_storage[prev_idx]; - } - } -} - -template -__global__ void warp_scan_kernel(const scalar_t *decays, - const scalar_t *impulses, - const scalar_t *initial_state, scalar_t *out, - scalar_t *decay_storage, scalar_t *h_storage, - int n_dims, int n_steps) { - int warp = threadIdx.x / 32; - int lane = threadIdx.x % 32; - - // Note: Due to the index ordering of the storage arrays, the following - // indices are equivalent: - // - // i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims - // i + 32 * n_dims + (blockIdx.x - 1) * 33 * n_dims - // - // when t is 0. This means something that looks like negative indexing - // (t-1) can be used to safely access the stored value for the previous - // warp (even if the previous warp belonged to the previous block). - - /* - * Scan over warps. - * After this loop executes, the storage arrays should contain the - * cumulative sum from the beginning of sequence (including initial - * condition) up to and including the indexed warp and block. - */ - // TODO: parallel scan - for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { - for (int t = 0; t < 32; t++) { - if (t == 0 && blockIdx.x == 0) { - // the reduction over warp 0 (including initial condition) is - // correct val for scan, so there's no work to do - continue; - } - - int cur_idx = i + t * n_dims + blockIdx.x * 33 * n_dims; - int prev_idx = i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims; - h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + - h_storage[cur_idx]; - decay_storage[cur_idx] *= decay_storage[prev_idx]; - } - } - - __syncthreads(); - - int2 start_stop = - compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); - int warp_start = start_stop.x; - int warp_stop = start_stop.y; - - /* - * Scan within warps. - * This loop writes to the output array. Each warp reads in it's initial - * state (either from the "initial_state" or the storage arrays) and then - * writes to output for indices warp_start up to warp_stop. - */ - for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { - scalar_t h = static_cast(0.0); - if (blockIdx.x == 0 && lane == 0) { - if (initial_state != NULL) { - h = initial_state[i]; - } - } else { - h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims]; - } - - for (int t = warp_start; t < warp_stop; t++) { - h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; - out[i * n_steps + t] = h; - } - } -} - -/* - * This is the main method for the prefix sum kernels. - * decays, impulses, out: - * each a n_dims x n_steps column major matrix located on GPU - * initial_state: - * array of size n_dims located on GPU - */ -template -void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, - const scalar_t *initial_state, scalar_t *out, - int n_dims, int n_steps) { - // we want at least 32 elements per block, but no reason to run - // with more than the maximum number of concurrent blocks - // NOTE: 128 is decided empirically. - int n_blocks = min(CEIL_DIV(n_steps, 32), 128); - - // TODO: make user pass in working memory? This allows integration - // with CNMeM (used by Theano) - int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(scalar_t); - scalar_t *d_reduction_mem; - gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz)); - scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims]; - scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims]; - - // TODO: run kernels on non-default stream? - reduction_kernel<<>>(decays, impulses, initial_state, - d_decay_storage, d_h_storage, n_dims, - n_steps); - - block_scan_kernel<<>>(d_decay_storage, d_h_storage, n_dims, - n_blocks); - - warp_scan_kernel<<>>(decays, impulses, initial_state, out, - d_decay_storage, d_h_storage, n_dims, - n_steps); - - gpuErrChk(cudaFree(d_reduction_mem)); -} - -at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, - const at::Tensor &initials) { - TORCH_CHECK(input.is_floating_point() || input.is_complex(), - "Input must be floating point or complex"); - TORCH_CHECK(initials.scalar_type() == input.scalar_type(), - "Initials must have the same scalar type as input"); - TORCH_CHECK(weights.scalar_type() == input.scalar_type(), - "Weights must have the same scalar type as input"); - - auto input_contiguous = input.contiguous(); - auto weights_contiguous = weights.contiguous(); - auto output = at::empty_like(input_contiguous); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( - input.scalar_type(), "compute_linear_recurrence", [&] { - compute_linear_recurrence( - weights_contiguous.const_data_ptr(), - input_contiguous.const_data_ptr(), - initials.const_data_ptr(), - output.mutable_data_ptr(), input_contiguous.size(0), - input_contiguous.size(1)); - }); - return output.contiguous(); -} - -TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); } diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu new file mode 100644 index 0000000..3f59771 --- /dev/null +++ b/torchlpc/csrc/cuda/scan.cu @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// template +// using scan_matrix = thrust::device_vector< +// cuda::std::pair, thrust::device_vector>>; + +// template +// cuda::std::pair, thrust::device_vector> +// recur_binary_op(const cuda::std::pair, +// thrust::device_vector> &a, +// const cuda::std::pair, +// thrust::device_vector> &b) { +// cuda::std::pair, thrust::device_vector> +// result; result.first.resize(a.first.size()); +// result.second.resize(a.second.size()); + +// ::cuda::std::multiplies mult_op; +// ::cuda::std::plus add_op; + +// thrust::transform(thrust::device, a.first.cbegin(), a.first.cend(), +// b.first.cbegin(), result.first.begin(), mult_op); +// thrust::transform(thrust::device, a.first.cbegin(), a.first.cend(), +// b.second.cbegin(), result.second.begin(), mult_op); +// thrust::transform(thrust::device, result.second.cbegin(), +// result.second.cend(), a.second.cbegin(), +// result.second.begin(), add_op); +// return result; +// } + +template +struct recur_binary_op { + __host__ __device__ cuda::std::pair operator()( + const cuda::std::pair &a, const cuda::std::pair &b) const { + return cuda::std::make_pair(a.first * b.first, + a.second * b.first + b.second); + } +}; + +template +void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, + const scalar_t *initial_state, scalar_t *out, + int n_steps) { + thrust::device_vector> input_states( + n_steps); + thrust::device_vector> output_states( + n_steps); + + // Initialize input_states and output_states + thrust::transform(thrust::device, decays, decays + n_steps, impulses, + input_states.begin(), + [=] __host__ __device__(const scalar_t &decay, + const scalar_t &impulse) { + return cuda::std::make_pair(decay, impulse); + }); + + // auto initial_state_pair = cuda::std::make_pair(0.0, initial_state[0]); + + recur_binary_op binary_op; + + thrust::inclusive_scan(thrust::device, input_states.begin(), + input_states.end(), output_states.begin(), + binary_op); + + thrust::transform(thrust::device, output_states.begin(), + output_states.end(), out, + [=] __host__ __device__( + const cuda::std::pair &state) { + // state + return state.second; + }); +} + +at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, + const at::Tensor &initials) { + TORCH_CHECK(input.is_floating_point() || input.is_complex(), + "Input must be floating point or complex"); + TORCH_CHECK(initials.scalar_type() == input.scalar_type(), + "Initials must have the same scalar type as input"); + TORCH_CHECK(weights.scalar_type() == input.scalar_type(), + "Weights must have the same scalar type as input"); + + // auto input_contiguous = input.contiguous(); + auto input_contiguous = + at::cat({initials.unsqueeze(1), input}, 1).contiguous(); + auto weights_contiguous = + at::cat({at::zeros_like(initials.unsqueeze(1)), weights}, 1) + .contiguous(); + auto output = at::empty_like(input_contiguous); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + input.scalar_type(), "compute_linear_recurrence", [&] { + compute_linear_recurrence( + weights_contiguous.const_data_ptr(), + input_contiguous.const_data_ptr(), + initials.const_data_ptr(), + output.mutable_data_ptr(), input_contiguous.numel()); + }); + return output.slice(1, 1, output.size(1)) + .contiguous(); // Remove the initial state from the output +} + +TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); } \ No newline at end of file From 735a58ceabf7f0ffd482a943cfecfffd0dd35d3a Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Mon, 21 Jul 2025 17:19:49 +0100 Subject: [PATCH 02/10] fix: add missing nvcc compile argument --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 7ebd99a..ebdb879 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ def get_extensions(): if use_cuda: sources += cuda_sources + extra_compile_args["nvcc"] = ["--extended-lambda"] ext_modules = [ extension( From 852026bd54c11c245af6aaf8be190961212ac69d Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Mon, 21 Jul 2025 17:34:52 +0100 Subject: [PATCH 03/10] refactor --- torchlpc/csrc/cuda/scan.cu | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index 3f59771..b3fa95d 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -52,30 +52,24 @@ template void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, const scalar_t *initial_state, scalar_t *out, int n_steps) { - thrust::device_vector> input_states( - n_steps); - thrust::device_vector> output_states( - n_steps); + thrust::device_vector> pairs(n_steps); // Initialize input_states and output_states - thrust::transform(thrust::device, decays, decays + n_steps, impulses, - input_states.begin(), - [=] __host__ __device__(const scalar_t &decay, - const scalar_t &impulse) { - return cuda::std::make_pair(decay, impulse); - }); + thrust::transform( + thrust::device, decays, decays + n_steps, impulses, pairs.begin(), + [] __host__ __device__(const scalar_t &decay, const scalar_t &impulse) { + return cuda::std::make_pair(decay, impulse); + }); // auto initial_state_pair = cuda::std::make_pair(0.0, initial_state[0]); recur_binary_op binary_op; - thrust::inclusive_scan(thrust::device, input_states.begin(), - input_states.end(), output_states.begin(), - binary_op); + thrust::inclusive_scan(thrust::device, pairs.begin(), pairs.end(), + pairs.begin(), binary_op); - thrust::transform(thrust::device, output_states.begin(), - output_states.end(), out, - [=] __host__ __device__( + thrust::transform(thrust::device, pairs.begin(), pairs.end(), out, + [] __host__ __device__( const cuda::std::pair &state) { // state return state.second; From 08984797a2de903a3d0ba8a53ee3259b7752a1fb Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Mon, 21 Jul 2025 18:29:34 +0100 Subject: [PATCH 04/10] refactor: remove unused argument --- torchlpc/csrc/cuda/scan.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index b3fa95d..23e1b93 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -50,8 +50,7 @@ struct recur_binary_op { template void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, - const scalar_t *initial_state, scalar_t *out, - int n_steps) { + scalar_t *out, int n_steps) { thrust::device_vector> pairs(n_steps); // Initialize input_states and output_states @@ -100,7 +99,6 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, compute_linear_recurrence( weights_contiguous.const_data_ptr(), input_contiguous.const_data_ptr(), - initials.const_data_ptr(), output.mutable_data_ptr(), input_contiguous.numel()); }); return output.slice(1, 1, output.size(1)) From 19dfad5c98af5e9bdf3b8d20e830ce8e098ce7ab Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Tue, 22 Jul 2025 14:33:42 +0100 Subject: [PATCH 05/10] try for_each to increase parallelism --- torchlpc/csrc/cuda/scan.cu | 78 +++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index 23e1b93..c51eae4 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -12,33 +12,6 @@ #include #include -// template -// using scan_matrix = thrust::device_vector< -// cuda::std::pair, thrust::device_vector>>; - -// template -// cuda::std::pair, thrust::device_vector> -// recur_binary_op(const cuda::std::pair, -// thrust::device_vector> &a, -// const cuda::std::pair, -// thrust::device_vector> &b) { -// cuda::std::pair, thrust::device_vector> -// result; result.first.resize(a.first.size()); -// result.second.resize(a.second.size()); - -// ::cuda::std::multiplies mult_op; -// ::cuda::std::plus add_op; - -// thrust::transform(thrust::device, a.first.cbegin(), a.first.cend(), -// b.first.cbegin(), result.first.begin(), mult_op); -// thrust::transform(thrust::device, a.first.cbegin(), a.first.cend(), -// b.second.cbegin(), result.second.begin(), mult_op); -// thrust::transform(thrust::device, result.second.cbegin(), -// result.second.cend(), a.second.cbegin(), -// result.second.begin(), add_op); -// return result; -// } - template struct recur_binary_op { __host__ __device__ cuda::std::pair operator()( @@ -48,6 +21,17 @@ struct recur_binary_op { } }; +template +struct scan_functor { + thrust::device_ptr> data; + int n_steps; + __host__ __device__ void operator()(int i) const { + thrust::inclusive_scan(thrust::device, data + i * n_steps, + data + (i + 1) * n_steps, data + i * n_steps, + recur_binary_op()); + } +}; + template void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, scalar_t *out, int n_steps) { @@ -70,7 +54,33 @@ void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, thrust::transform(thrust::device, pairs.begin(), pairs.end(), out, [] __host__ __device__( const cuda::std::pair &state) { - // state + return state.second; + }); +} + +template +void compute_linear_recurrence2(const scalar_t *decays, + const scalar_t *impulses, + // const scalar_t *initials, + scalar_t *out, int n_dims, int n_steps) { + thrust::device_vector> pairs(n_steps * + n_dims); + thrust::transform( + thrust::device, decays, decays + n_steps * n_dims, impulses, + pairs.begin(), + [] __host__ __device__(const scalar_t &decay, const scalar_t &impulse) { + return cuda::std::make_pair(decay, impulse); + }); + + recur_binary_op binary_op; + thrust::counting_iterator it(0); + scan_functor scan_op{pairs.data(), n_steps}; + + thrust::for_each(thrust::device, it, it + n_dims, scan_op); + + thrust::transform(thrust::device, pairs.begin(), pairs.end(), out, + [] __host__ __device__( + const cuda::std::pair &state) { return state.second; }); } @@ -84,7 +94,6 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, TORCH_CHECK(weights.scalar_type() == input.scalar_type(), "Weights must have the same scalar type as input"); - // auto input_contiguous = input.contiguous(); auto input_contiguous = at::cat({initials.unsqueeze(1), input}, 1).contiguous(); auto weights_contiguous = @@ -96,10 +105,17 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( input.scalar_type(), "compute_linear_recurrence", [&] { - compute_linear_recurrence( + // compute_linear_recurrence( + // weights_contiguous.const_data_ptr(), + // input_contiguous.const_data_ptr(), + // output.mutable_data_ptr(), + // input_contiguous.numel()); + compute_linear_recurrence2( weights_contiguous.const_data_ptr(), input_contiguous.const_data_ptr(), - output.mutable_data_ptr(), input_contiguous.numel()); + // initials.const_data_ptr(), + output.mutable_data_ptr(), input_contiguous.size(0), + input_contiguous.size(1)); }); return output.slice(1, 1, output.size(1)) .contiguous(); // Remove the initial state from the output From 4d9b29bb950aedc087701faeb31e27e03c56c280 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Tue, 22 Jul 2025 14:46:26 +0100 Subject: [PATCH 06/10] fallback to using flatten long scan --- torchlpc/csrc/cuda/scan.cu | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index c51eae4..4be6b1b 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -105,17 +105,16 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( input.scalar_type(), "compute_linear_recurrence", [&] { - // compute_linear_recurrence( + compute_linear_recurrence( + weights_contiguous.const_data_ptr(), + input_contiguous.const_data_ptr(), + output.mutable_data_ptr(), input_contiguous.numel()); + // compute_linear_recurrence2( // weights_contiguous.const_data_ptr(), // input_contiguous.const_data_ptr(), + // // initials.const_data_ptr(), // output.mutable_data_ptr(), - // input_contiguous.numel()); - compute_linear_recurrence2( - weights_contiguous.const_data_ptr(), - input_contiguous.const_data_ptr(), - // initials.const_data_ptr(), - output.mutable_data_ptr(), input_contiguous.size(0), - input_contiguous.size(1)); + // input_contiguous.size(0), input_contiguous.size(1)); }); return output.slice(1, 1, output.size(1)) .contiguous(); // Remove the initial state from the output From a60cfe7d1a29033d2064d0f475ad7e1ab60c0bc4 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 23 Jul 2025 20:04:24 +0100 Subject: [PATCH 07/10] refactor: enhance scan operations with transform input/output iterators --- torchlpc/csrc/cuda/scan.cu | 56 ++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index 4be6b1b..0ca701a 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,10 +15,31 @@ template struct recur_binary_op { + __host__ __device__ cuda::std::tuple operator()( + const cuda::std::tuple &a, + const cuda::std::tuple &b) const { + auto a_first = thrust::get<0>(a); + auto a_second = thrust::get<1>(a); + auto b_first = thrust::get<0>(b); + auto b_second = thrust::get<1>(b); + return cuda::std::make_tuple(a_first * b_first, + a_second * b_first + b_second); + } +}; + +template +struct input_unary_op { __host__ __device__ cuda::std::pair operator()( - const cuda::std::pair &a, const cuda::std::pair &b) const { - return cuda::std::make_pair(a.first * b.first, - a.second * b.first + b.second); + const T &decay, const T &impulse) const { + return cuda::std::make_pair(decay, impulse); + } +}; + +template +struct output_unary_op { + __host__ __device__ T + operator()(const cuda::std::tuple &state) const { + return thrust::get<1>(state); } }; @@ -35,27 +57,13 @@ struct scan_functor { template void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, scalar_t *out, int n_steps) { - thrust::device_vector> pairs(n_steps); - - // Initialize input_states and output_states - thrust::transform( - thrust::device, decays, decays + n_steps, impulses, pairs.begin(), - [] __host__ __device__(const scalar_t &decay, const scalar_t &impulse) { - return cuda::std::make_pair(decay, impulse); - }); - - // auto initial_state_pair = cuda::std::make_pair(0.0, initial_state[0]); - - recur_binary_op binary_op; - - thrust::inclusive_scan(thrust::device, pairs.begin(), pairs.end(), - pairs.begin(), binary_op); - - thrust::transform(thrust::device, pairs.begin(), pairs.end(), out, - [] __host__ __device__( - const cuda::std::pair &state) { - return state.second; - }); + thrust::inclusive_scan( + thrust::device, thrust::make_zip_iterator(decays, impulses), + thrust::make_zip_iterator(decays + n_steps, impulses + n_steps), + thrust::make_transform_output_iterator(out, + // thrust::get<1>), + output_unary_op()), + recur_binary_op()); } template From d1d87322a6b549b5378919793c8d813c610303f3 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 23 Jul 2025 21:04:07 +0100 Subject: [PATCH 08/10] refactor recur2 version --- torchlpc/csrc/cuda/scan.cu | 57 ++++++++++++-------------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu index 0ca701a..d4fc6f7 100644 --- a/torchlpc/csrc/cuda/scan.cu +++ b/torchlpc/csrc/cuda/scan.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -27,14 +28,6 @@ struct recur_binary_op { } }; -template -struct input_unary_op { - __host__ __device__ cuda::std::pair operator()( - const T &decay, const T &impulse) const { - return cuda::std::make_pair(decay, impulse); - } -}; - template struct output_unary_op { __host__ __device__ T @@ -43,54 +36,38 @@ struct output_unary_op { } }; -template -struct scan_functor { - thrust::device_ptr> data; - int n_steps; - __host__ __device__ void operator()(int i) const { - thrust::inclusive_scan(thrust::device, data + i * n_steps, - data + (i + 1) * n_steps, data + i * n_steps, - recur_binary_op()); - } -}; - template -void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, - scalar_t *out, int n_steps) { +__host__ __device__ void compute_linear_recurrence(const scalar_t *decays, + const scalar_t *impulses, + scalar_t *out, int n_steps) { thrust::inclusive_scan( thrust::device, thrust::make_zip_iterator(decays, impulses), thrust::make_zip_iterator(decays + n_steps, impulses + n_steps), thrust::make_transform_output_iterator(out, - // thrust::get<1>), output_unary_op()), recur_binary_op()); } +template +struct scan_functor { + const T *decays, *impulses; + T *out; + int n_steps; + __host__ __device__ void operator()(int i) { + compute_linear_recurrence(decays + i * n_steps, + impulses + i * n_steps, out + i * n_steps, + n_steps); + } +}; + template void compute_linear_recurrence2(const scalar_t *decays, const scalar_t *impulses, // const scalar_t *initials, scalar_t *out, int n_dims, int n_steps) { - thrust::device_vector> pairs(n_steps * - n_dims); - thrust::transform( - thrust::device, decays, decays + n_steps * n_dims, impulses, - pairs.begin(), - [] __host__ __device__(const scalar_t &decay, const scalar_t &impulse) { - return cuda::std::make_pair(decay, impulse); - }); - - recur_binary_op binary_op; thrust::counting_iterator it(0); - scan_functor scan_op{pairs.data(), n_steps}; - + scan_functor scan_op{decays, impulses, out, n_steps}; thrust::for_each(thrust::device, it, it + n_dims, scan_op); - - thrust::transform(thrust::device, pairs.begin(), pairs.end(), out, - [] __host__ __device__( - const cuda::std::pair &state) { - return state.second; - }); } at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, From 25f8bb19343f0b6cf64c571b02d14d3d3020680a Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 23 Jul 2025 21:36:21 +0100 Subject: [PATCH 09/10] refactor: add support for experimental library compilation on macOS --- setup.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/setup.py b/setup.py index ebdb879..6dff2f7 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ import setuptools import os import glob +from platform import system import torch from torch.utils.cpp_extension import ( CppExtension, @@ -36,6 +37,13 @@ def get_extensions(): extra_compile_args["cxx"] = ["-fopenmp"] extra_link_args.append("-fopenmp") + if system() == "Darwin": + extra_compile_args["cxx"] = ( + ["-fexperimental-library"] + if "cxx" not in extra_compile_args + else extra_compile_args["cxx"] + ["-fexperimental-library"] + ) + this_dir = os.path.abspath(os.path.dirname(__file__)) extensions_dir = os.path.join(this_dir, library_name, "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) From 82b415930844f36e8d1d36610da8ff724d035c42 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 24 Jul 2025 14:27:07 +0100 Subject: [PATCH 10/10] ci: update macOS runner and clang++ version for newer c++ standard --- .github/workflows/python-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c25e3f9..9a16185 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -46,7 +46,7 @@ jobs: build-macos: if: github.event_name == 'pull_request' - runs-on: macos-latest + runs-on: macos-15 strategy: fail-fast: false matrix: @@ -74,7 +74,7 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Build CPP extension with clang++ run: | - export CXX=$(brew --prefix llvm@15)/bin/clang++ + export CXX=$(brew --prefix llvm@18)/bin/clang++ export LDFLAGS="-L/usr/local/opt/libomp/lib" export CPPFLAGS="-I/usr/local/opt/libomp/include" python -m pip install -e .