diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index c25e3f9..9a16185 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -46,7 +46,7 @@ jobs:
   build-macos:
     if: github.event_name == 'pull_request'
-    runs-on: macos-latest
+    runs-on: macos-15
     strategy:
       fail-fast: false
       matrix:
@@ -74,7 +74,7 @@ jobs:
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Build CPP extension with clang++
       run: |
-        export CXX=$(brew --prefix llvm@15)/bin/clang++
+        export CXX=$(brew --prefix llvm@18)/bin/clang++
         export LDFLAGS="-L/usr/local/opt/libomp/lib"
         export CPPFLAGS="-I/usr/local/opt/libomp/include"
         python -m pip install -e .
diff --git a/setup.py b/setup.py
index 7ebd99a..6dff2f7 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
 import setuptools
 import os
 import glob
+from platform import system
 import torch
 from torch.utils.cpp_extension import (
     CppExtension,
@@ -36,6 +37,13 @@ def get_extensions():
         extra_compile_args["cxx"] = ["-fopenmp"]
         extra_link_args.append("-fopenmp")
 
+    if system() == "Darwin":
+        extra_compile_args["cxx"] = (
+            ["-fexperimental-library"]
+            if "cxx" not in extra_compile_args
+            else extra_compile_args["cxx"] + ["-fexperimental-library"]
+        )
+
     this_dir = os.path.abspath(os.path.dirname(__file__))
     extensions_dir = os.path.join(this_dir, library_name, "csrc")
     sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
@@ -45,6 +53,7 @@
 
     if use_cuda:
         sources += cuda_sources
+        extra_compile_args["nvcc"] = ["--extended-lambda"]
 
     ext_modules = [
         extension(
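Aside (not part of the patch): the Darwin branch above appends `-fexperimental-library` (needed by Homebrew LLVM's libc++ for some experimental standard-library features) while handling the case where the OpenMP branch has already populated `extra_compile_args["cxx"]`. A minimal sketch of an equivalent, arguably simpler formulation using `dict.setdefault`:

```python
from platform import system

extra_compile_args: dict[str, list[str]] = {}

# Equivalent to the conditional expression in the patch: create the "cxx"
# entry if it is missing, then append the flag either way.
if system() == "Darwin":
    extra_compile_args.setdefault("cxx", []).append("-fexperimental-library")
```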
diff --git a/torchlpc/csrc/cuda/linear_recurrence.cu b/torchlpc/csrc/cuda/linear_recurrence.cu
deleted file mode 100644
index a7c7bb9..0000000
--- a/torchlpc/csrc/cuda/linear_recurrence.cu
+++ /dev/null
@@ -1,291 +0,0 @@
-#include <ATen/ATen.h>
-#include <ATen/Dispatch.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <cstdio>
-#include <cuda_runtime.h>
-#include <torch/library.h>
-
-#define CEIL_DIV(x, y) ((x + y - 1) / y)
-
-#define gpuErrChk(ans) \
-    { \
-        gpuAssert((ans), __FILE__, __LINE__); \
-    }
-void gpuAssert(cudaError_t code, const char *file, int line) {
-    if (code != cudaSuccess) {
-        fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
-                line);
-    }
-}
-
-__device__ int2 divide_work(int n_jobs, int n_workers, int worker_idx) {
-    // Each worker will do a continuous slice of either n_jobs / n_workers
-    // or ceil_div(n_jobs, n_workers). The return value is an int2 representing
-    // a half open interval of jobs for the worker to perform (perform jobs
-    // i for a <= i < b).
-
-    int cd = CEIL_DIV(n_jobs, n_workers);
-    int d = n_jobs / n_workers;
-
-    int doing_cd = n_jobs % n_workers;
-
-    int2 retval;
-    if (worker_idx < doing_cd) {
-        retval.x = worker_idx * cd;
-        retval.y = retval.x + cd;
-    } else {
-        retval.x = doing_cd * cd + (worker_idx - doing_cd) * d;
-        retval.y = retval.x + d;
-    }
-
-    return retval;
-}
-
-__device__ int2 compute_warp_start_stop(int block_idx, int warp_idx,
-                                        int n_blocks, int n_steps) {
-    int2 block_ss = divide_work(n_steps, n_blocks, block_idx);
-    int block_start = block_ss.x;
-    int block_stop = block_ss.y;
-    int block_jobs = block_stop - block_start;
-
-    int2 warp_ss = divide_work(block_jobs, 32, warp_idx);
-    int warp_start = block_start + warp_ss.x;
-    int warp_stop = block_start + warp_ss.y;
-
-    int2 retval;
-    retval.x = warp_start;
-    retval.y = warp_stop;
-    return retval;
-}
-
-// decay storage, h_storage:
-//   each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block
-//   reduction
-template <typename scalar_t>
-__global__ void reduction_kernel(const scalar_t *decays,
-                                 const scalar_t *impulses,
-                                 const scalar_t *initial_state,
-                                 scalar_t *_decay_storage, scalar_t *_h_storage,
-                                 int n_dims, int n_steps) {
-    int warp = threadIdx.x / 32;
-    int lane = threadIdx.x % 32;
-
-    scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims];
-    scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
-
-    int2 start_stop =
-        compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
-    int warp_start = start_stop.x;
-    int warp_stop = start_stop.y;
-
-    /*
-     * Reduce within warps.
-     * After this loop exits, the storage arrays should contain the reduction
-     * from warp_start to warp_stop (including initial state) at index
-     * (feature_idx, warp, block).
-     */
-    for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
-        scalar_t cum_decay = static_cast<scalar_t>(1.0);
-        scalar_t h = static_cast<scalar_t>(0.0);
-        if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) {
-            h = initial_state[i];
-        }
-
-        for (int t = warp_start; t < warp_stop; t++) {
-            cum_decay *= decays[i * n_steps + t];
-            h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
-        }
-
-        // TODO: store into shared memory, work in shared memory sized blocks
-        // store into global memory
-        decay_storage[i + lane * n_dims] = cum_decay;
-        h_storage[i + lane * n_dims] = h;
-    }
-
-    __syncthreads();
-
-    /*
-     * Reduce over warps.
-     * After this loop exits, the storage arrays should contain the reduction
-     * from block_start to block_finish (including initial state) at index
-     * (feature_idx, 32, block).
-     */
-    // TODO: parallel reduction (or scan). Need to worry about changing the
-    // warp reduction values (as I use them again later)
-    for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
-        scalar_t cum_decay = static_cast<scalar_t>(1.0);
-        scalar_t h = static_cast<scalar_t>(0.0);
-        for (int t = 0; t < 32; t++) {
-            cum_decay *= decay_storage[i + t * n_dims];
-            h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims];
-        }
-        decay_storage[i + 32 * n_dims] = cum_decay;
-        h_storage[i + 32 * n_dims] = h;
-    }
-}
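Aside (not part of the patch): the per-lane loop above folds each slice of the sequence into a pair `(cum_decay, h)` with the invariant that, for any state `h0` entering the slice, the true state after the slice is `cum_decay * h0 + h`. A small Python reference check of that invariant, with hypothetical helper names:

```python
def fold_slice(decays, impulses, h0=0.0):
    """Reference: run the recurrence h = a * h + x over one slice."""
    h = h0
    for a, x in zip(decays, impulses):
        h = a * h + x
    return h

def reduce_slice(decays, impulses):
    """Fold a slice into (cum_decay, h), with h computed from h0 = 0."""
    cum, h = 1.0, 0.0
    for a, x in zip(decays, impulses):
        cum *= a
        h = a * h + x
    return cum, h

a, x, h0 = [0.9, 0.5, 1.1], [1.0, -2.0, 0.3], 4.0
cum, h = reduce_slice(a, x)
# The summary pair lets any later pass resume from an arbitrary h0.
assert abs(fold_slice(a, x, h0) - (cum * h0 + h)) < 1e-12
```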
-
-template <typename scalar_t>
-__global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage,
-                                  int n_dims, int n_blocks) {
-    /*
-     * Scan over blocks.
-     * After this loop exits, the storage arrays should contain the cumulative
-     * sum from block_idx 0 to i (inclusive) at index (feature_idx, 32, i).
-     * This means (feature_idx, 32, 2) contains the reduction of blocks 0, 1,
-     * and 2.
-     */
-    // TODO: parallel scan (tricky because number of blocks isn't necessarily
-    // smaller than number of warps that can fit in a single block)
-    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n_dims;
-         i += blockDim.x * gridDim.x) {
-        for (int t = 1; t < n_blocks; t++) {
-            int cur_idx = i + 32 * n_dims + t * 33 * n_dims;
-            int prev_idx = i + 32 * n_dims + (t - 1) * 33 * n_dims;
-
-            // TODO: remove unnecessary reads from global memory (prev_idx
-            // accesses)
-            h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] +
-                                 h_storage[cur_idx];
-            decay_storage[cur_idx] *= decay_storage[prev_idx];
-        }
-    }
-}
-
-template <typename scalar_t>
-__global__ void warp_scan_kernel(const scalar_t *decays,
-                                 const scalar_t *impulses,
-                                 const scalar_t *initial_state, scalar_t *out,
-                                 scalar_t *decay_storage, scalar_t *h_storage,
-                                 int n_dims, int n_steps) {
-    int warp = threadIdx.x / 32;
-    int lane = threadIdx.x % 32;
-
-    // Note: Due to the index ordering of the storage arrays, the following
-    // indices are equivalent:
-    //
-    //   i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims
-    //   i + 32 * n_dims + (blockIdx.x - 1) * 33 * n_dims
-    //
-    // when t is 0. This means something that looks like negative indexing
-    // (t-1) can be used to safely access the stored value for the previous
-    // warp (even if the previous warp belonged to the previous block).
-
-    /*
-     * Scan over warps.
-     * After this loop executes, the storage arrays should contain the
-     * cumulative sum from the beginning of sequence (including initial
-     * condition) up to and including the indexed warp and block.
-     */
-    // TODO: parallel scan
-    for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
-        for (int t = 0; t < 32; t++) {
-            if (t == 0 && blockIdx.x == 0) {
-                // the reduction over warp 0 (including initial condition) is
-                // correct val for scan, so there's no work to do
-                continue;
-            }
-
-            int cur_idx = i + t * n_dims + blockIdx.x * 33 * n_dims;
-            int prev_idx = i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims;
-            h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] +
-                                 h_storage[cur_idx];
-            decay_storage[cur_idx] *= decay_storage[prev_idx];
-        }
-    }
-
-    __syncthreads();
-
-    int2 start_stop =
-        compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
-    int warp_start = start_stop.x;
-    int warp_stop = start_stop.y;
-
-    /*
-     * Scan within warps.
-     * This loop writes to the output array. Each warp reads in its initial
-     * state (either from the "initial_state" or the storage arrays) and then
-     * writes to output for indices warp_start up to warp_stop.
-     */
-    for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
-        scalar_t h = static_cast<scalar_t>(0.0);
-        if (blockIdx.x == 0 && lane == 0) {
-            if (initial_state != NULL) {
-                h = initial_state[i];
-            }
-        } else {
-            h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims];
-        }
-
-        for (int t = warp_start; t < warp_stop; t++) {
-            h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
-            out[i * n_steps + t] = h;
-        }
-    }
-}
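Aside (not part of the patch): taken together, the three kernels above implement a classic blocked scan — `reduction_kernel` folds each chunk into a summary pair, `block_scan_kernel` scans the summaries to obtain the state entering each chunk, and `warp_scan_kernel` replays the recurrence inside each chunk seeded with that carried state. A compact single-level Python reference of the same decomposition (the kernels add a second block/warp level, omitted here):

```python
def blocked_recurrence(decays, impulses, h0, n_chunks):
    """Reference for the three-kernel pipeline on one feature dimension."""
    n = len(decays)
    bounds = [i * n // n_chunks for i in range(n_chunks + 1)]
    chunks = list(zip(bounds, bounds[1:]))

    # Phase 1 (reduction_kernel): fold each chunk into (cum_decay, h).
    summaries = []
    for s, e in chunks:
        cum, h = 1.0, 0.0
        for a, x in zip(decays[s:e], impulses[s:e]):
            cum, h = cum * a, a * h + x
        summaries.append((cum, h))

    # Phase 2 (block_scan_kernel): an exclusive scan of the summaries
    # yields the state entering each chunk.
    carries, carry = [], h0
    for cum, h in summaries:
        carries.append(carry)
        carry = cum * carry + h

    # Phase 3 (warp_scan_kernel): replay the recurrence inside each chunk,
    # seeded with its carry, emitting every intermediate state.
    out = []
    for (s, e), h in zip(chunks, carries):
        for a, x in zip(decays[s:e], impulses[s:e]):
            h = a * h + x
            out.append(h)
    return out

# Sanity check against the plain sequential recurrence.
a = [0.9, 0.5, 1.1, 0.8, 1.0, 0.7]
x = [1.0, -2.0, 0.3, 0.5, -0.1, 2.0]
h, ref = 4.0, []
for ai, xi in zip(a, x):
    h = ai * h + xi
    ref.append(h)
assert all(abs(u - v) < 1e-9
           for u, v in zip(blocked_recurrence(a, x, 4.0, 3), ref))
```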
-
-/*
- * This is the main method for the prefix sum kernels.
- * decays, impulses, out:
- *   each a n_dims x n_steps column major matrix located on GPU
- * initial_state:
- *   array of size n_dims located on GPU
- */
-template <typename scalar_t>
-void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
-                               const scalar_t *initial_state, scalar_t *out,
-                               int n_dims, int n_steps) {
-    // we want at least 32 elements per block, but no reason to run
-    // with more than the maximum number of concurrent blocks
-    // NOTE: 128 is decided empirically.
-    int n_blocks = min(CEIL_DIV(n_steps, 32), 128);
-
-    // TODO: make user pass in working memory? This allows integration
-    // with CNMeM (used by Theano)
-    int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(scalar_t);
-    scalar_t *d_reduction_mem;
-    gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz));
-    scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims];
-    scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims];
-
-    // TODO: run kernels on non-default stream?
-    reduction_kernel<<<n_blocks, 1024>>>(decays, impulses, initial_state,
-                                         d_decay_storage, d_h_storage, n_dims,
-                                         n_steps);
-
-    block_scan_kernel<<<32, 32>>>(d_decay_storage, d_h_storage, n_dims,
-                                  n_blocks);
-
-    warp_scan_kernel<<<n_blocks, 1024>>>(decays, impulses, initial_state, out,
-                                         d_decay_storage, d_h_storage, n_dims,
-                                         n_steps);
-
-    gpuErrChk(cudaFree(d_reduction_mem));
-}
-
-at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
-                             const at::Tensor &initials) {
-    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
-                "Input must be floating point or complex");
-    TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
-                "Initials must have the same scalar type as input");
-    TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
-                "Weights must have the same scalar type as input");
-
-    auto input_contiguous = input.contiguous();
-    auto weights_contiguous = weights.contiguous();
-    auto output = at::empty_like(input_contiguous);
-
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
-        input.scalar_type(), "compute_linear_recurrence", [&] {
-            compute_linear_recurrence(
-                weights_contiguous.const_data_ptr<scalar_t>(),
-                input_contiguous.const_data_ptr<scalar_t>(),
-                initials.const_data_ptr<scalar_t>(),
-                output.mutable_data_ptr<scalar_t>(), input_contiguous.size(0),
-                input_contiguous.size(1));
-        });
-    return output.contiguous();
-}
-
-TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
diff --git a/torchlpc/csrc/cuda/scan.cu b/torchlpc/csrc/cuda/scan.cu
new file mode 100644
index 0000000..d4fc6f7
--- /dev/null
+++ b/torchlpc/csrc/cuda/scan.cu
@@ -0,0 +1,108 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda/std/tuple>
+#include <thrust/execution_policy.h>
+#include <thrust/for_each.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/transform_output_iterator.h>
+#include <thrust/iterator/zip_iterator.h>
+#include <thrust/scan.h>
+#include <thrust/tuple.h>
+#include <torch/library.h>
+
+template <typename T>
+struct recur_binary_op {
+    __host__ __device__ cuda::std::tuple<T, T> operator()(
+        const cuda::std::tuple<T, T> &a,
+        const cuda::std::tuple<T, T> &b) const {
+        auto a_first = thrust::get<0>(a);
+        auto a_second = thrust::get<1>(a);
+        auto b_first = thrust::get<0>(b);
+        auto b_second = thrust::get<1>(b);
+        return cuda::std::make_tuple(a_first * b_first,
+                                     a_second * b_first + b_second);
+    }
+};
+
+template <typename T>
+struct output_unary_op {
+    __host__ __device__ T
+    operator()(const cuda::std::tuple<T, T> &state) const {
+        return thrust::get<1>(state);
+    }
+};
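Aside (not part of the patch): `recur_binary_op` encodes step composition — applying `(a1, x1)` and then `(a2, x2)` to a state `h` gives `a2 * (a1 * h + x1) + x2 = (a1 * a2) * h + (x1 * a2 + x2)`, which is exactly the returned pair. Because this composition is associative, the whole recurrence reduces to a single inclusive scan over `(decay, impulse)` pairs, with `output_unary_op` projecting out the state component. A quick Python check:

```python
from itertools import accumulate

def combine(p, q):
    # (a1, x1) followed by (a2, x2) -> (a1 * a2, x1 * a2 + x2)
    a1, x1 = p
    a2, x2 = q
    return a1 * a2, x1 * a2 + x2

p, q, r = (0.9, 1.0), (0.5, -2.0), (1.1, 0.3)

# Associativity: the property inclusive_scan relies on.
lhs, rhs = combine(combine(p, q), r), combine(p, combine(q, r))
assert all(abs(u - v) < 1e-12 for u, v in zip(lhs, rhs))

# The second component of each scanned prefix is h_t = a_t * h_{t-1} + x_t
# with h_{-1} = 0, i.e. exactly what output_unary_op extracts.
hs = [h for _, h in accumulate([p, q, r], combine)]
h, ref = 0.0, []
for a, x in [p, q, r]:
    h = a * h + x
    ref.append(h)
assert all(abs(u - v) < 1e-12 for u, v in zip(hs, ref))
```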
+
+template <typename scalar_t>
+__host__ __device__ void compute_linear_recurrence(const scalar_t *decays,
+                                                   const scalar_t *impulses,
+                                                   scalar_t *out, int n_steps) {
+    thrust::inclusive_scan(
+        thrust::device,
+        thrust::make_zip_iterator(decays, impulses),
+        thrust::make_zip_iterator(decays + n_steps, impulses + n_steps),
+        thrust::make_transform_output_iterator(out,
+                                               output_unary_op<scalar_t>()),
+        recur_binary_op<scalar_t>());
+}
+
+template <typename T>
+struct scan_functor {
+    const T *decays, *impulses;
+    T *out;
+    int n_steps;
+    __host__ __device__ void operator()(int i) {
+        compute_linear_recurrence(decays + i * n_steps,
+                                  impulses + i * n_steps, out + i * n_steps,
+                                  n_steps);
+    }
+};
+
+template <typename scalar_t>
+void compute_linear_recurrence2(const scalar_t *decays,
+                                const scalar_t *impulses,
+                                // const scalar_t *initials,
+                                scalar_t *out, int n_dims, int n_steps) {
+    thrust::counting_iterator<int> it(0);
+    scan_functor<scalar_t> scan_op{decays, impulses, out, n_steps};
+    thrust::for_each(thrust::device, it, it + n_dims, scan_op);
+}
+
+at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
+                             const at::Tensor &initials) {
+    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
+                "Initials must have the same scalar type as input");
+    TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
+                "Weights must have the same scalar type as input");
+
+    // Prepend the initial state as an extra first step with zero decay, so a
+    // single flat scan over the whole buffer restarts at each row boundary.
+    auto input_contiguous =
+        at::cat({initials.unsqueeze(1), input}, 1).contiguous();
+    auto weights_contiguous =
+        at::cat({at::zeros_like(initials.unsqueeze(1)), weights}, 1)
+            .contiguous();
+    auto output = at::empty_like(input_contiguous);
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+        input.scalar_type(), "compute_linear_recurrence", [&] {
+            compute_linear_recurrence(
+                weights_contiguous.const_data_ptr<scalar_t>(),
+                input_contiguous.const_data_ptr<scalar_t>(),
+                output.mutable_data_ptr<scalar_t>(), input_contiguous.numel());
+            // compute_linear_recurrence2(
+            //     weights_contiguous.const_data_ptr<scalar_t>(),
+            //     input_contiguous.const_data_ptr<scalar_t>(),
+            //     // initials.const_data_ptr<scalar_t>(),
+            //     output.mutable_data_ptr<scalar_t>(),
+            //     input_contiguous.size(0), input_contiguous.size(1));
+        });
+    return output.slice(1, 1, output.size(1))
+        .contiguous();  // Remove the initial state from the output
+}
+
+TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
\ No newline at end of file
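Aside (not part of the patch): the registered op can be exercised from Python roughly as follows. This is a sketch — it assumes the `torchlpc` extension defines the `scan` schema elsewhere as `(Tensor input, Tensor weights, Tensor initials) -> Tensor` with `(batch, time)` inputs matching the wrapper above, and that importing the package loads the compiled library:

```python
import torch
import torchlpc  # assumed to load the extension and register torchlpc::scan

B, T = 4, 1024
x = torch.randn(B, T, device="cuda")         # impulses (input)
a = torch.rand(B, T, device="cuda") * 0.9    # per-step decays (weights)
zi = torch.randn(B, device="cuda")           # initial states h_{-1} (initials)

y = torch.ops.torchlpc.scan(x, a, zi)

# Reference recurrence: y[:, t] = a[:, t] * y[:, t-1] + x[:, t], seeded with zi.
ref, h = torch.empty_like(x), zi
for t in range(T):
    h = a[:, t] * h + x[:, t]
    ref[:, t] = h
torch.testing.assert_close(y, ref)
```

Note the design choice in the wrapper: because each row's prepended step has decay zero, the recurrence resets at every row boundary, which is what makes scanning the flattened `(batch * time)` buffer in one `inclusive_scan` call safe.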