# Patch summary: adds a ROCm/HIP build path for the libtorchaudio RNN-T sources.
#   * .gitmodules / third_party/hipify_torch: registers the hipify_torch submodule.
#   * CMakeLists.txt: calls find_package(Torch REQUIRED) and, under USE_ROCM, enable_language(HIP).
#   * src/libtorchaudio/CMakeLists.txt: finds HIP, includes Hipify, hipifies rnnt/gpu into rnnt/hip,
#     appends the .hip sources, links hip::host / hip::device and adds the USE_ROCM definition.
#   * rnnt/gpu/*, rnnt/macros.h, rnnt/options.h, rnnt/workspace.h: choose includes on
#     __HIP_PLATFORM_AMD__, add a USE_ROCM macro branch, add the gpuStream_t typedef and a
#     hipMemset path for the counters.
#
# Fixes applied in this revision (all on added lines; hunk line counts are unchanged):
#   1. The "hip found" status message prints ${HIP_FOUND}, the result variable of
#      find_package(HIP). ${ROCM_FOUND} is not set anywhere in this patch.
#      TODO confirm cmake/LoadHIP.cmake does not define ROCM_FOUND.
#   2. OpenMP::OpenMP_CXX is gated on NOT USE_ROCM instead of USE_CUDA, so CPU-only builds keep
#      linking OpenMP exactly as they did before this patch.
#   3. rnnt/macros.h uses "#elif defined(USE_ROCM)", matching the "#ifdef USE_ROCM" and
#      "defined(USE_ROCM)" tests used elsewhere in this patch; this also works when the macro is
#      defined without a value.
#
# NOTE(review): this copy of the patch has its newlines collapsed and the target of every
#   "#include" in the C++ hunks is missing; restore both from the repository before applying.
# NOTE(review): context line "message(FATAL ...)" -- FATAL is not a message() mode (FATAL_ERROR
#   is), so configuration presumably continues when both USE_CUDA and USE_ROCM are set. It is a
#   context line here, so it needs its own change.
# NOTE(review): "#set(USE_ROCM OFF)" leaves the if(NOT PYTORCH_FOUND_HIP) body empty, so
#   USE_ROCM stays on when HIP is not found. Restore the fallback or fail with FATAL_ERROR.
# NOTE(review): CMAKE_CXX_COMPILER is overwritten in a subdirectory, after the top-level file has
#   already enabled languages; CMake documents the compiler as fixed once detected.
#   CMAKE_CXX_LINKER does not appear to be a CMake variable -- confirm it has any effect.
# NOTE(review): /opt/rocm, /opt/rocm/hip and /opt/rocm/llvm/lib are hard-coded; consider deriving
#   them from a ROCM_PATH cache or environment variable.
# NOTE(review): WARP_SIZE is 32 in the USE_ROCM branch; confirm this matches the wavefront size
#   of the AMD targets being built for.
# NOTE(review): the hipMemset return values in workspace.h are ignored.
#
diff --git a/.gitmodules b/.gitmodules index e69de29bb2..25d307cea8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/hipify_torch"] + path = third_party/hipify_torch + url = https://github.com/ROCmSoftwarePlatform/hipify_torch diff --git a/CMakeLists.txt b/CMakeLists.txt index ddc6dc15a2..42ae6d66b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,11 +70,16 @@ if(USE_CUDA AND USE_ROCM) message(FATAL "CUDA and ROCm are mutually exclusive") endif() +find_package(Torch REQUIRED) + if(USE_ROCM) + + enable_language(HIP) + # Find the HIP package, set the HIP paths, load the HIP CMake. include(cmake/LoadHIP.cmake) if(NOT PYTORCH_FOUND_HIP) - set(USE_ROCM OFF) + #set(USE_ROCM OFF) endif() endif() diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 713cb50533..e504900378 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -1,6 +1,23 @@ ################################################################################ # libtorchaudio ################################################################################ + +if(USE_ROCM) + list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + FIND_PACKAGE(HIP REQUIRED) + MESSAGE(STATUS "hip found ${HIP_FOUND}") + + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake") + include(Hipify) + + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + list( APPEND CMAKE_INSTALL_RPATH "/opt/rocm/llvm/lib" ) + +endif() + + set( sources lfilter.cpp @@ -39,6 +56,23 @@ if(BUILD_RNNT) rnnt/gpu/compute.cu ) endif() + + if (USE_ROCM) + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/gpu HIP_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/hip) + if ( NOT HIP_ADD_LIBRARY_FOUND ) + list(APPEND CMAKE_MODULE_PATH /opt/rocm/hip/cmake) + find_package(HIP REQUIRED) + endif() + + list( + APPEND +
sources + rnnt/hip/compute_alphas.hip + rnnt/hip/compute_betas.hip + rnnt/hip/compute.hip + ) + endif() + endif() if(BUILD_RIR) @@ -76,12 +110,29 @@ if(USE_CUDA) ) endif() -if(OpenMP_CXX_FOUND) +if(USE_ROCM) list( APPEND - additional_libs - OpenMP::OpenMP_CXX + additional_libs + hip::host + hip::device ) + list( + APPEND + compile_definitions + USE_ROCM + ) +endif() + + +if(NOT USE_ROCM) + if(OpenMP_CXX_FOUND) + list( + APPEND + additional_libs + OpenMP::OpenMP_CXX + ) + endif() endif() #------------------------------------------------------------------------------# diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 43dae68027..b48443af36 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,6 +1,10 @@ #include -#include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu index bde40daa9f..d19ace1bec 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu @@ -1,6 +1,10 @@ #include -#include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu index 18857c4388..d57be262f2 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_betas.cu @@ -1,6 +1,10 @@ #include -#include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh index f4ad3add2b..af4e6608f2 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh @@ -2,7 +2,11 @@ #ifdef
USE_CUDA +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh index 136e6844f2..eb22fe8cdf 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh @@ -4,9 +4,15 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include +#else #include #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_transducer.h b/src/libtorchaudio/rnnt/gpu/gpu_transducer.h index 875c47974f..570935b3be 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_transducer.h +++ b/src/libtorchaudio/rnnt/gpu/gpu_transducer.h @@ -3,8 +3,13 @@ #ifdef USE_CUDA #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/kernel_utils.h b/src/libtorchaudio/rnnt/gpu/kernel_utils.h index 9cfaf42cdd..57450f13a8 100644 --- a/src/libtorchaudio/rnnt/gpu/kernel_utils.h +++ b/src/libtorchaudio/rnnt/gpu/kernel_utils.h @@ -2,7 +2,11 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/kernels.h b/src/libtorchaudio/rnnt/gpu/kernels.h index 5f327d3ee3..0f15d50ed8 100644 --- a/src/libtorchaudio/rnnt/gpu/kernels.h +++ b/src/libtorchaudio/rnnt/gpu/kernels.h @@ -2,8 +2,13 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/macros.h b/src/libtorchaudio/rnnt/macros.h index cdc83dd5d2..f1677f9198 100644 --- a/src/libtorchaudio/rnnt/macros.h +++ b/src/libtorchaudio/rnnt/macros.h @@ -8,6 +8,14 @@ #define FORCE_INLINE __forceinline__ #include #include +#elif defined(USE_ROCM) +#define WARP_SIZE 32 +#define MAX_THREADS_PER_BLOCK 1024 +#define REDUCE_THREADS
256 +#define HOST_AND_DEVICE __host__ __device__ +#define FORCE_INLINE __forceinline__ +#include +#include #else #define HOST_AND_DEVICE #define FORCE_INLINE inline diff --git a/src/libtorchaudio/rnnt/options.h b/src/libtorchaudio/rnnt/options.h index 8a8fed1116..c3b5bdfa4d 100644 --- a/src/libtorchaudio/rnnt/options.h +++ b/src/libtorchaudio/rnnt/options.h @@ -2,7 +2,12 @@ #ifdef USE_CUDA #include +typedef cudaStream_t gpuStream_t; #endif // USE_CUDA +#ifdef USE_ROCM +#include +typedef hipStream_t gpuStream_t; +#endif // USE_ROCM #include #include @@ -13,9 +18,9 @@ namespace rnnt { struct Options { // the device to compute transducer loss. device_t device_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // the stream to launch kernels in when using GPU. - cudaStream_t stream_; + gpuStream_t stream_; #endif // The maximum number of threads that can be used. int numThreads_; diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index b4bbb30a43..0d457c5c78 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -133,10 +133,22 @@ class IntWorkspace { ComputeSizeForBetaCounters(options_) * sizeof(int)); } #endif // USE_CUDA +#ifdef USE_ROCM + if (data_ != nullptr && options_.device_ == GPU) { + hipMemset( + GetPointerToAlphaCounters(), + 0, + ComputeSizeForAlphaCounters(options_) * sizeof(int)); + hipMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); + } +#endif // USE_ROCM } static int ComputeSizeForAlphaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else { @@ -147,7 +159,7 @@ class IntWorkspace { #endif // USE_CUDA } static int ComputeSizeForBetaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else { diff --git
a/third_party/hipify_torch b/third_party/hipify_torch new file mode 160000 index 0000000000..a4337c69fe --- /dev/null +++ b/third_party/hipify_torch @@ -0,0 +1 @@ +Subproject commit a4337c69fe0e2552a7b7b0669178926beeed828c