Skip to content

Commit 1bbf78b

Browse files
committed
[STABLE ABI] Port lfilter
1 parent 32ce8c0 commit 1bbf78b

File tree

5 files changed

+202
-114
lines changed

5 files changed

+202
-114
lines changed

src/libtorchaudio/iir_cuda.cu

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
#include <libtorchaudio/utils.h>
2+
#include <torch/headeronly/core/Dispatch_v2.h>
3+
#include <torch/headeronly/core/ScalarType.h>
14
#include <c10/cuda/CUDAException.h>
25
#include <c10/cuda/CUDAGuard.h>
3-
#include <torch/torch.h>
6+
#include <c10/core/DeviceGuard.h>
7+
8+
using torch::headeronly::ScalarType;
9+
using torch::stable::Tensor;
410

511
template <typename scalar_t>
612
__global__ void iir_cu_kernel(
7-
const torch::
8-
PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> in,
9-
const torch::
10-
PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
11-
a_flipped,
12-
torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>
13-
out) {
13+
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> in,
14+
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 2> a_flipped,
15+
torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> out) {
1416
int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
1517
int64_t n = in.size(0);
1618
int64_t c = in.size(1);
@@ -33,51 +35,49 @@ __global__ void iir_cu_kernel(
3335
}
3436
}
3537

36-
void cuda_lfilter_core_loop(
37-
const torch::Tensor& in,
38-
const torch::Tensor& a_flipped,
39-
torch::Tensor& padded_out) {
40-
TORCH_CHECK(
41-
in.device().is_cuda() && a_flipped.device().is_cuda() &&
42-
padded_out.device().is_cuda());
38+
Tensor cuda_lfilter_core_loop(
39+
Tensor in,
40+
Tensor a_flipped,
41+
Tensor padded_out) {
42+
STD_TORCH_CHECK(
43+
in.is_cuda() && a_flipped.is_cuda() &&
44+
padded_out.is_cuda());
4345

44-
TORCH_CHECK(
46+
STD_TORCH_CHECK(
47+
(in.get_device_index() == a_flipped.get_device_index()) &&
48+
(in.get_device_index() == padded_out.get_device_index()));
49+
50+
STD_TORCH_CHECK(
4551
in.is_contiguous() && a_flipped.is_contiguous() &&
4652
padded_out.is_contiguous());
4753

48-
TORCH_CHECK(
49-
(in.dtype() == torch::kFloat32 || in.dtype() == torch::kFloat64) &&
50-
(a_flipped.dtype() == torch::kFloat32 ||
51-
a_flipped.dtype() == torch::kFloat64) &&
52-
(padded_out.dtype() == torch::kFloat32 ||
53-
padded_out.dtype() == torch::kFloat64));
54+
STD_TORCH_CHECK(
55+
(in.scalar_type() == ScalarType::Float || in.scalar_type() == ScalarType::Double) &&
56+
(a_flipped.scalar_type() == ScalarType::Float ||
57+
a_flipped.scalar_type() == ScalarType::Double) &&
58+
(padded_out.scalar_type() == ScalarType::Float ||
59+
padded_out.scalar_type() == ScalarType::Double));
5460

5561
const int N = in.size(0);
5662
const int C = in.size(1);
57-
TORCH_CHECK(N == padded_out.size(0));
58-
TORCH_CHECK(C == padded_out.size(1));
63+
STD_TORCH_CHECK(N == padded_out.size(0));
64+
STD_TORCH_CHECK(C == padded_out.size(1));
5965

60-
TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
66+
STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
6167

62-
const at::cuda::OptionalCUDAGuard device_guard(device_of(in));
68+
// TODO: enable device guard:
69+
//const at::cuda::OptionalCUDAGuard device_guard(in.device());
6370

6471
const dim3 threads(256);
6572
const dim3 blocks((N * C + threads.x - 1) / threads.x);
6673

67-
AT_DISPATCH_FLOATING_TYPES(
68-
in.scalar_type(), "iir_cu_loop", ([&] {
69-
iir_cu_kernel<scalar_t><<<blocks, threads>>>(
70-
in.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
71-
a_flipped.packed_accessor<
72-
scalar_t,
73-
2,
74-
torch::RestrictPtrTraits,
75-
size_t>(),
76-
padded_out.packed_accessor<
77-
scalar_t,
78-
3,
79-
torch::RestrictPtrTraits,
80-
size_t>());
74+
THO_DISPATCH_V2(
75+
in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
76+
(iir_cu_kernel<scalar_t><<<blocks, threads>>>(
77+
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
78+
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
79+
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
8180
C10_CUDA_KERNEL_LAUNCH_CHECK();
82-
}));
81+
}), AT_FLOATING_TYPES);
82+
return padded_out;
8383
}

src/libtorchaudio/iir_cuda.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#pragma once
22

3-
#include <torch/types.h>
3+
#include <torch/csrc/stable/tensor.h>
44

5-
void cuda_lfilter_core_loop(
6-
const torch::Tensor& in,
7-
const torch::Tensor& a_flipped,
8-
torch::Tensor& padded_out);
5+
using torch::stable::Tensor;
6+
7+
// CUDA entry point for the lfilter core loop.
// Takes the three tensors by value and returns a Tensor; the definition in
// iir_cuda.cu validates its arguments, launches `iir_cu_kernel`, and returns
// `padded_out`.
Tensor cuda_lfilter_core_loop(Tensor in, Tensor a_flipped, Tensor padded_out);

src/libtorchaudio/lfilter.cpp

Lines changed: 92 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,140 @@
1-
#include <torch/script.h>
2-
#include <torch/torch.h>
1+
#include <libtorchaudio/utils.h>
2+
#include <torch/csrc/stable/library.h>
3+
#include <torch/csrc/stable/ops.h>
4+
#include <torch/headeronly/core/Dispatch_v2.h>
5+
#include <torch/headeronly/core/ScalarType.h>
36

47
#ifdef USE_CUDA
58
#include <libtorchaudio/iir_cuda.h>
69
#endif
710

811
namespace {
912

13+
using torch::headeronly::ScalarType;
14+
using torch::stable::Tensor;
15+
1016
template <typename scalar_t>
1117
void host_lfilter_core_loop(
12-
const torch::Tensor& input_signal_windows,
13-
const torch::Tensor& a_coeff_flipped,
14-
torch::Tensor& padded_output_waveform) {
18+
const Tensor& input_signal_windows,
19+
const Tensor& a_coeff_flipped,
20+
Tensor& padded_output_waveform) {
1521
int64_t n_batch = input_signal_windows.size(0);
1622
int64_t n_channel = input_signal_windows.size(1);
1723
int64_t n_samples_input = input_signal_windows.size(2);
1824
int64_t n_samples_output = padded_output_waveform.size(2);
1925
int64_t n_order = a_coeff_flipped.size(1);
20-
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
21-
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
22-
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
23-
24-
at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
25-
for (auto i = begin; i < end; i++) {
26-
int64_t offset_input = i * n_samples_input;
27-
int64_t offset_output = i * n_samples_output;
28-
int64_t i_channel = i % n_channel;
29-
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
30-
scalar_t a0 = input_data[offset_input + i_sample];
31-
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
32-
a0 -= output_data[offset_output + i_sample + i_coeff] *
33-
a_coeff_flipped_data[i_coeff + i_channel * n_order];
26+
scalar_t* output_data =
27+
reinterpret_cast<scalar_t*>(padded_output_waveform.data_ptr());
28+
const scalar_t* input_data =
29+
reinterpret_cast<scalar_t*>(input_signal_windows.data_ptr());
30+
const scalar_t* a_coeff_flipped_data =
31+
reinterpret_cast<scalar_t*>(a_coeff_flipped.data_ptr());
32+
33+
torch::stable::parallel_for(
34+
0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
35+
for (auto i = begin; i < end; i++) {
36+
int64_t offset_input = i * n_samples_input;
37+
int64_t offset_output = i * n_samples_output;
38+
int64_t i_channel = i % n_channel;
39+
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
40+
scalar_t a0 = input_data[offset_input + i_sample];
41+
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
42+
a0 -= output_data[offset_output + i_sample + i_coeff] *
43+
a_coeff_flipped_data[i_coeff + i_channel * n_order];
44+
}
45+
output_data[offset_output + i_sample + n_order - 1] = a0;
46+
}
3447
}
35-
output_data[offset_output + i_sample + n_order - 1] = a0;
36-
}
37-
}
38-
});
48+
});
3949
}
4050

41-
void cpu_lfilter_core_loop(
42-
const torch::Tensor& input_signal_windows,
43-
const torch::Tensor& a_coeff_flipped,
44-
torch::Tensor& padded_output_waveform) {
45-
TORCH_CHECK(
46-
input_signal_windows.device().is_cpu() &&
47-
a_coeff_flipped.device().is_cpu() &&
48-
padded_output_waveform.device().is_cpu());
51+
Tensor cpu_lfilter_core_loop(
52+
Tensor input_signal_windows,
53+
Tensor a_coeff_flipped,
54+
Tensor padded_output_waveform) {
55+
STD_TORCH_CHECK(
56+
input_signal_windows.is_cpu() && a_coeff_flipped.is_cpu() &&
57+
padded_output_waveform.is_cpu());
4958

50-
TORCH_CHECK(
59+
STD_TORCH_CHECK(
5160
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
5261
padded_output_waveform.is_contiguous());
5362

54-
TORCH_CHECK(
55-
(input_signal_windows.dtype() == torch::kFloat32 ||
56-
input_signal_windows.dtype() == torch::kFloat64) &&
57-
(a_coeff_flipped.dtype() == torch::kFloat32 ||
58-
a_coeff_flipped.dtype() == torch::kFloat64) &&
59-
(padded_output_waveform.dtype() == torch::kFloat32 ||
60-
padded_output_waveform.dtype() == torch::kFloat64));
63+
STD_TORCH_CHECK(
64+
(input_signal_windows.scalar_type() == ScalarType::Float ||
65+
input_signal_windows.scalar_type() == ScalarType::Double) &&
66+
(a_coeff_flipped.scalar_type() == ScalarType::Float ||
67+
a_coeff_flipped.scalar_type() == ScalarType::Double) &&
68+
(padded_output_waveform.scalar_type() == ScalarType::Float ||
69+
padded_output_waveform.scalar_type() == ScalarType::Double));
6170

62-
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
63-
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
71+
STD_TORCH_CHECK(
72+
input_signal_windows.size(0) == padded_output_waveform.size(0));
73+
STD_TORCH_CHECK(
74+
input_signal_windows.size(1) == padded_output_waveform.size(1));
6475

65-
TORCH_CHECK(
76+
STD_TORCH_CHECK(
6677
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
6778
padded_output_waveform.size(2));
6879

69-
AT_DISPATCH_FLOATING_TYPES(
70-
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
80+
THO_DISPATCH_V2(
81+
input_signal_windows.scalar_type(),
82+
"lfilter_core_loop",
83+
[&] {
7184
host_lfilter_core_loop<scalar_t>(
7285
input_signal_windows, a_coeff_flipped, padded_output_waveform);
73-
});
86+
},
87+
AT_FLOATING_TYPES);
88+
return padded_output_waveform;
7489
}
7590

76-
void lfilter_core_generic_loop(
77-
const torch::Tensor& input_signal_windows,
78-
const torch::Tensor& a_coeff_flipped,
79-
torch::Tensor& padded_output_waveform) {
91+
Tensor lfilter_core_generic_loop(
92+
Tensor input_signal_windows,
93+
Tensor a_coeff_flipped,
94+
Tensor padded_output_waveform) {
8095
int64_t n_samples_input = input_signal_windows.size(2);
8196
int64_t n_order = a_coeff_flipped.size(1);
82-
auto coeff = a_coeff_flipped.unsqueeze(2);
97+
auto coeff = torchaudio::stable::unsqueeze(a_coeff_flipped, 2);
8398
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
84-
auto windowed_output_signal =
85-
torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order)
86-
.transpose(0, 1);
87-
auto o0 = torch::select(input_signal_windows, 2, i_sample) -
88-
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
89-
padded_output_waveform.index_put_(
90-
{torch::indexing::Slice(),
91-
torch::indexing::Slice(),
92-
i_sample + n_order - 1},
93-
o0);
99+
auto windowed_output_signal = torch::stable::transpose(
100+
torch::stable::narrow(
101+
padded_output_waveform, 2, i_sample, i_sample + n_order),
102+
0,
103+
1);
104+
auto o0 = torchaudio::stable::subtract(
105+
torchaudio::stable::select(input_signal_windows, 2, i_sample),
106+
torch::stable::transpose(
107+
torchaudio::stable::squeeze(
108+
torchaudio::stable::matmul(windowed_output_signal, coeff), 2),
109+
0,
110+
1));
111+
auto s = torchaudio::stable::select(
112+
padded_output_waveform, 2, i_sample + n_order - 1);
113+
torch::stable::copy_(s, o0);
94114
}
115+
return padded_output_waveform;
95116
}
96117

97118
} // namespace
98119

99-
// Operator schema: the third argument is mutated in place and is also the
// returned tensor (alias set `a`).
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "_lfilter_core_loop(Tensor input_signal_windows,Tensor a_coeff_flipped,Tensor(a!) padded_output_waveform) -> Tensor(a!)");
}

// CPU kernel.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("_lfilter_core_loop", TORCH_BOX(&cpu_lfilter_core_loop));
}

// CUDA kernel, only registered when built with CUDA support.
#ifdef USE_CUDA
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("_lfilter_core_loop", TORCH_BOX(&cuda_lfilter_core_loop));
}
#endif

// Tensor-op based implementation registered under the
// CompositeExplicitAutograd key.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
  m.impl("_lfilter_core_loop", TORCH_BOX(&lfilter_core_generic_loop));
}

src/libtorchaudio/stable/ops.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,51 @@ T item(const Tensor& self) {
182182
}
183183
}
184184

185+
// Calls the `aten::unsqueeze` operator (default overload) through
// torch_call_dispatcher.
//
// `self` and `dim` are packed into a StableIValue stack; after the call the
// result is read back from slot 0 of that same stack. A non-zero error code
// from the dispatcher call is handled by TORCH_ERROR_CODE_CHECK.
inline Tensor unsqueeze(const Tensor& self, int64_t dim) {
  using torch::stable::detail::from;
  std::array<StableIValue, 2> args{from(self), from(dim)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::unsqueeze", "", args.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(args[0]);
}
193+
194+
// Calls `aten::select.int(self, dim, index)` through torch_call_dispatcher.
//
// `self`, `dim` and `index` are packed into a StableIValue stack; after the
// call the result is read back from slot 0 of that same stack. A non-zero
// error code from the dispatcher call is handled by TORCH_ERROR_CODE_CHECK.
inline Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  const auto num_args = 3;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self),
      torch::stable::detail::from(dim),
      torch::stable::detail::from(index)};
  // aten::select is only registered with named overloads (`int` and
  // `Dimname`); an empty overload name does not resolve to a schema.
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::select", "int", stack.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
204+
205+
// Calls the `dim` overload of `aten::squeeze` through torch_call_dispatcher.
//
// `self` and `dim` are packed into a StableIValue stack; after the call the
// result is read back from slot 0 of that same stack. A non-zero error code
// from the dispatcher call is handled by TORCH_ERROR_CODE_CHECK.
inline Tensor squeeze(const Tensor& self, int64_t dim) {
  using torch::stable::detail::from;
  std::array<StableIValue, 2> args{from(self), from(dim)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::squeeze", "dim", args.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(args[0]);
}
213+
214+
// Calls the `aten::matmul` operator (default overload) through
// torch_call_dispatcher.
//
// `self` and `other` are packed into a StableIValue stack; after the call the
// result is read back from slot 0 of that same stack. A non-zero error code
// from the dispatcher call is handled by TORCH_ERROR_CODE_CHECK.
inline Tensor matmul(const Tensor& self, const Tensor& other) {
  using torch::stable::detail::from;
  std::array<StableIValue, 2> args{from(self), from(other)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::matmul", "", args.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(args[0]);
}
222+
223+
// Calls the `Tensor` overload of `aten::subtract` through
// torch_call_dispatcher.
//
// `self` and `other` are packed into a two-slot StableIValue stack; after the
// call the result is read back from slot 0 of that same stack. A non-zero
// error code from the dispatcher call is handled by TORCH_ERROR_CODE_CHECK.
//
// NOTE(review): the aten::subtract.Tensor schema also has a keyword-only
// `alpha` argument. Only two stack slots are provided here - confirm that
// torch_call_dispatcher does not expect a slot for every schema argument.
inline Tensor subtract(const Tensor& self, const Tensor& other) {
  using torch::stable::detail::from;
  std::array<StableIValue, 2> args{from(self), from(other)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::subtract", "Tensor", args.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(args[0]);
}
231+
185232
} // namespace torchaudio::stable

0 commit comments

Comments
 (0)