Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5484f33
fix build error on ROCM (#2)
micmelesse Aug 27, 2021
bdddf4b
Fix build error 2 rocm pr 2 (#3)
micmelesse Aug 30, 2021
673faa1
Disable tests (#5)
micmelesse Sep 8, 2021
c516a52
remove cmakelists, since it is moved inside, resolve import statement
amd-sriram Feb 16, 2026
b232095
add source code for rnnt loss, define TORCH_HIP_VERSION variable
amd-sriram Feb 19, 2026
fa315a1
add namespace shim file
amd-sriram Feb 19, 2026
c13048a
fix THO_DISPATCH syntax so that hipification works
amd-sriram Feb 20, 2026
6c8ab7f
fix torch version, add lfilter rocm to sources
amd-sriram Feb 20, 2026
5b46635
add rocm source code for forced align, add shim namespace file for rocm
amd-sriram Feb 21, 2026
ab3424f
remove extra skip if rocm flags for certain unit tests
amd-sriram Feb 21, 2026
45aec76
build cuda ctc decoder for rocm
amd-sriram Feb 21, 2026
bc95506
the lfilter test passes, the other tests are the same, so removing skip-if-ROCm flags [subject truncated in page capture; reconstructed from related commits]
amd-sriram Feb 21, 2026
553f9b0
add end of file
amd-sriram Feb 23, 2026
f795931
fix ufmt issue
amd-sriram Feb 23, 2026
e7bad35
fix clang format
amd-sriram Feb 23, 2026
82acc52
remove skip if rocm flag from the test
amd-sriram Feb 23, 2026
b9c7682
Merge pull request #12 from ROCm/rocm_rnnt_loss_feature
amd-sriram Feb 23, 2026
e132d40
Merge branch 'pytorch:main' into main
amd-sriram Feb 23, 2026
11bbe29
remove hipblas flags
amd-sriram Mar 1, 2026
189dda5
Merge pull request #13 from ROCm/fix_rnnt_loss
amd-sriram Mar 1, 2026
cf99efb
remove hip namespace shim
amd-sriram Mar 5, 2026
faf4a35
remove reference of hip namespace shim
amd-sriram Mar 6, 2026
4c189d2
Merge pull request #14 from ROCm/fix_rnnt_loss_r211
amd-sriram Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/libtorchaudio/iir_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ Tensor cuda_lfilter_core_loop(
const dim3 blocks((N * C + threads.x - 1) / threads.x);

THO_DISPATCH_V2(
in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
(iir_cu_kernel<scalar_t><<<blocks, threads>>>(
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
STD_CUDA_KERNEL_LAUNCH_CHECK();
}), AT_FLOATING_TYPES);
in.scalar_type(), "iir_cu_loop", AT_WRAP(([&]() {
iir_cu_kernel<scalar_t><<<blocks, threads>>>(
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out));
STD_CUDA_KERNEL_LAUNCH_CHECK();
})), AT_FLOATING_TYPES);
return padded_out;
}
19 changes: 15 additions & 4 deletions tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _get_build(var, default=False):
_USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
_BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA or _USE_ROCM)
_USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)

Expand Down Expand Up @@ -71,14 +71,25 @@ def get_ext_modules():
extension = CUDAExtension
extra_compile_args["cxx"].append("-DUSE_CUDA")
extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"]
if _USE_ROCM:
extension = CUDAExtension
extra_compile_args["nvcc"] = ["-O3"]
# TORCH_HIP_VERSION is used by hipified C++ (e.g. utils_hip.cpp); PyTorch only defines it when building PyTorch.
if torch.version.hip:
parts = torch.version.hip.split(".")
major = int(parts[0]) if len(parts) > 0 else 0
minor = int(parts[1]) if len(parts) > 1 else 0
torch_hip_version = major * 100 + minor # e.g. 7.1.x -> 701
extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))

sources = [
"utils.cpp",
"lfilter.cpp",
"overdrive.cpp",
]

if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("iir_cuda.cu")

if _BUILD_RNNT:
Expand All @@ -88,7 +99,7 @@ def get_ext_modules():
"rnnt/compute.cpp",
]
)
if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("rnnt/gpu/compute.cu")

if _BUILD_ALIGN:
Expand All @@ -99,7 +110,7 @@ def get_ext_modules():
"forced_align/compute.cpp",
]
)
if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("forced_align/gpu/compute.cu")

modules = [
Expand Down