diff --git a/src/libtorchaudio/iir_cuda.cu b/src/libtorchaudio/iir_cuda.cu index 658bca4c54..31919ee617 100644 --- a/src/libtorchaudio/iir_cuda.cu +++ b/src/libtorchaudio/iir_cuda.cu @@ -69,12 +69,12 @@ Tensor cuda_lfilter_core_loop( const dim3 blocks((N * C + threads.x - 1) / threads.x); THO_DISPATCH_V2( - in.scalar_type(), "iir_cu_loop", AT_WRAP([&] { - (iir_cu_kernel<<>>( - torchaudio::packed_accessor_size_t(in), - torchaudio::packed_accessor_size_t(a_flipped), - torchaudio::packed_accessor_size_t(padded_out))); - STD_CUDA_KERNEL_LAUNCH_CHECK(); - }), AT_FLOATING_TYPES); + in.scalar_type(), "iir_cu_loop", AT_WRAP(([&]() { + iir_cu_kernel<<>>( + torchaudio::packed_accessor_size_t(in), + torchaudio::packed_accessor_size_t(a_flipped), + torchaudio::packed_accessor_size_t(padded_out)); + STD_CUDA_KERNEL_LAUNCH_CHECK(); + })), AT_FLOATING_TYPES); return padded_out; } diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py index 447dd5091d..a8d0689cb5 100644 --- a/tools/setup_helpers/extension.py +++ b/tools/setup_helpers/extension.py @@ -34,7 +34,7 @@ def _get_build(var, default=False): _USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None) _USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None) _BUILD_ALIGN = _get_build("BUILD_ALIGN", True) -_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA) +_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA or _USE_ROCM) _USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info() _TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None) @@ -71,6 +71,17 @@ def get_ext_modules(): extension = CUDAExtension extra_compile_args["cxx"].append("-DUSE_CUDA") extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"] + if _USE_ROCM: + extension = CUDAExtension + extra_compile_args["nvcc"] = ["-O3"] + # TORCH_HIP_VERSION is used by hipified C++ (e.g. utils_hip.cpp); PyTorch only defines it when building PyTorch. + if torch.version.hip: + parts = torch.version.hip.split(".") + major = int(parts[0]) if len(parts) > 0 else 0 + minor = int(parts[1]) if len(parts) > 1 else 0 + torch_hip_version = major * 100 + minor # e.g. 7.1.x -> 701 + extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version)) + extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version)) sources = [ "utils.cpp", @@ -78,7 +89,7 @@ def get_ext_modules(): "overdrive.cpp", ] - if _USE_CUDA: + if _USE_CUDA or _USE_ROCM: sources.append("iir_cuda.cu") if _BUILD_RNNT: @@ -88,7 +99,7 @@ def get_ext_modules(): "rnnt/compute.cpp", ] ) - if _USE_CUDA: + if _USE_CUDA or _USE_ROCM: sources.append("rnnt/gpu/compute.cu") if _BUILD_ALIGN: @@ -99,7 +110,7 @@ def get_ext_modules(): "forced_align/compute.cpp", ] ) - if _USE_CUDA: + if _USE_CUDA or _USE_ROCM: sources.append("forced_align/gpu/compute.cu") modules = [