diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh index f311c8370e..559b55437a 100755 --- a/.github/scripts/unittest-linux/run_test.sh +++ b/.github/scripts/unittest-linux/run_test.sh @@ -30,5 +30,5 @@ fi ( cd test - pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs" + pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)" ) diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 713cb50533..85bc227cd6 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -28,7 +28,6 @@ if(BUILD_RNNT) rnnt/compute_alphas.cpp rnnt/compute_betas.cpp rnnt/compute.cpp - rnnt/autograd.cpp ) if (USE_CUDA) list( diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp deleted file mode 100644 index dcf68409ed..0000000000 --- a/src/libtorchaudio/rnnt/autograd.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include - -namespace torchaudio { -namespace rnnt { - -class RNNTLossFunction : public torch::autograd::Function { - public: - static torch::autograd::tensor_list forward( - torch::autograd::AutogradContext* ctx, - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - torch::Tensor undef; - auto result = rnnt_loss( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); - auto costs = std::get<0>(result); - auto grads = std::get<1>(result).value_or(undef); - ctx->save_for_backward({grads}); - return {costs, grads}; - } - - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::tensor_list grad_outputs) { - auto saved = 
ctx->get_saved_variables(); - auto grad = saved[0]; - auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); - auto result = grad * grad_out; - torch::Tensor undef; - return {result, undef, undef, undef, undef, undef, undef, undef}; - } -}; - -std::tuple> rnnt_loss_autograd( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - at::AutoDispatchBelowADInplaceOrView guard; - auto results = RNNTLossFunction::apply( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); - return std::make_tuple(results[0], results[1]); -} - -TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { - m.impl("rnnt_loss", rnnt_loss_autograd); -} - -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp index 567c9b5d4b..867542e4e7 100644 --- a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -1,27 +1,8 @@ -#include +#include +#include -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - static auto op = torch::Dispatcher::singleton() - .findSchemaOrThrow("torchaudio::rnnt_loss", "") - .typed(); - return op.call( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); -} -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss(Tensor logits," "Tensor targets," @@ -29,5 +10,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { "Tensor target_lengths," "int blank," "float clamp," - "bool fused_log_softmax) -> (Tensor, Tensor?)"); + "bool fused_log_softmax) -> (Tensor, Tensor)"); } diff --git a/src/libtorchaudio/rnnt/compute.h b/src/libtorchaudio/rnnt/compute.h 
deleted file mode 100644 index ed2dd0c37e..0000000000 --- a/src/libtorchaudio/rnnt/compute.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax); diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp index adbcc1c8e7..dd187f9777 100644 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/compute_alphas.cpp @@ -1,6 +1,6 @@ -#include +#include -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss_alphas(Tensor logits," "Tensor targets," diff --git a/src/libtorchaudio/rnnt/compute_betas.cpp b/src/libtorchaudio/rnnt/compute_betas.cpp index 7728838137..b1cd379a66 100644 --- a/src/libtorchaudio/rnnt/compute_betas.cpp +++ b/src/libtorchaudio/rnnt/compute_betas.cpp @@ -1,6 +1,6 @@ -#include +#include -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss_betas(Tensor logits," "Tensor targets," diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 097b4bd7e1..817f79ef99 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -1,148 +1,212 @@ #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace cpu { +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - 
logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), - "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - targets.size(0) == logits.size(0), - "batch dimension mismatch between logits and targets"); - - TORCH_CHECK( - blank >= 0 && blank < logits.size(-1), - "blank must be within [0, logits.shape[-1])"); - - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), - "input length mismatch"); - 
TORCH_CHECK(
-      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
-      "output length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(target_lengths).item().toInt(),
-      "target length mismatch");
+
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  int32_t logit_lengths_device;
+  aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device);
+  int32_t target_lengths_device;
+  aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device);
+
+  AOTI_TORCH_CHECK(logits_device == targets_device);
+  AOTI_TORCH_CHECK(logits_device == logit_lengths_device);
+  AOTI_TORCH_CHECK(logits_device == target_lengths_device);
+
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+  AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() ||
+                   logits_dtype == aoti_torch_dtype_float16());
+
+  // targets and the two length tensors must be int32, matching the original
+  // TORCH_CHECK validation (no float16 fallback).
+  int32_t targets_dtype;
+  aoti_torch_get_dtype(targets.get(), &targets_dtype);
+  AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32());
+
+  int32_t logit_lengths_dtype;
+  aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype);
+  AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32());
+
+  int32_t target_lengths_dtype;
+  aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype);
+  AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32());
+
+  bool bool_tmp;
+  aoti_torch_is_contiguous(logits.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(targets.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+
+  int64_t int_tmp;
+  
aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + + // TORCH_CHECK( + // blank >= 0 && blank < logits.size(-1), + // "blank must be within [0, logits.shape[-1])"); + + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); + // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)target_lengths_size; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, &int_tmp); + options.maxTgtLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; options.blank_ = blank; options.clamp_ = 
clamp; options.fusedLogSmax_ = fused_log_softmax; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cpu()); options.device_ = CPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace 
workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - 
}; - return std::make_tuple(costs, gradients); + return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); +} + +void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from((std::get<0>(result)).release()); + stack[1] = from((std::get<1>(result)).release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss", &compute); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss", &boxed_compute); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp index 6923cbe5d8..40ed538175 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp @@ -1,68 +1,126 @@ #include -#include +#include +#include +#include + +// TODO: +// Are the StableIValue AtenTensorHandles reference counted at all? +// Why do we call release() on returned arguments? 
namespace torchaudio { namespace rnnt { namespace cpu { -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH compute_alphas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; options.blank_ = blank; options.clamp_ = clamp; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu()); + options.device_ = CPU; - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + int64_t param_sizes[3] = {options.batchSize_ * 
options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle alphas; + aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); + aoti_torch_zero_(alphas); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + 
aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *alpha_ptr; + aoti_torch_get_data_ptr(alphas, &alpha_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*alphas=*/(float*)alpha_ptr); + return RAIIATH(alphas); +} + +void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_alphas", &boxed_compute_alphas); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp index d812ef34c3..729e86a722 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp @@ -1,73 +1,130 @@ #include #include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace cpu { -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH 
compute_betas(
+    const RAIIATH logits,
+    const RAIIATH targets,
+    const RAIIATH logit_lengths,
+    const RAIIATH target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
+  int64_t tmp;
+  aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
+  options.batchSize_ = (int)tmp;
+  aoti_torch_get_size(target_lengths.get(), 0, &tmp);
+  options.nHypos_ = (int)tmp;
+  options.nHypos_ /= options.batchSize_;
+  aoti_torch_get_size(logits.get(), 1, &tmp);
+  options.maxSrcLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 2, &tmp);
+  options.maxTgtLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 3, &tmp);
+  options.numTargets_ = (int)tmp;
   options.blank_ = blank;
   options.clamp_ = clamp;
 
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
+  int32_t logits_device_type;
+  aoti_torch_get_device_type(logits.get(), &logits_device_type);
+  AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu());
+
   options.device_ = CPU;
 
-  torch::Tensor costs = torch::empty(
-      target_lengths.size(0),
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t logits_device_index;
+  aoti_torch_get_device_index(logits.get(), &logits_device_index);
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+
+  // costs must hold one entry per hypothesis, i.e. target_lengths.size(0)
+  // == batchSize * nHypos, matching the original torch::empty allocation.
+  int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_};
+  int64_t stride1[1] = {1};
+  AtenTensorHandle costs;
+  aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs);
+
+  int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_};
+  int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};
+  
AtenTensorHandle betas;
+  aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas);
+  // Original used torch::zeros; empty_strided leaves memory uninitialized,
+  // so betas must be explicitly zeroed (mirrors the alphas implementation).
+  aoti_torch_zero_(betas);
 
-  torch::Tensor betas = torch::zeros(
-      {options.batchSize_ * options.nHypos_,
-       options.maxSrcLen_,
-       options.maxTgtLen_},
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
+  AtenTensorHandle int_workspace;
+  int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
+  aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
 
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
+  AtenTensorHandle float_workspace;
+  aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
 
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  int64_t float_numel;
+  aoti_torch_get_numel(float_workspace, &float_numel);
+  void *int_workspace_ptr;
+  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
+  void *float_workspace_ptr;
+  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
+  int64_t int_numel;
+  aoti_torch_get_numel(int_workspace, &int_numel);
 
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
-      /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
-      /*int_size=*/int_workspace.numel());
+      /*dtype_data=*/(float*)float_workspace_ptr,
+      /*dtype_size=*/float_numel,
+      /*int_data=*/(int*)int_workspace_ptr,
+      /*int_size=*/int_numel);
+
+  void *logit_ptr;
+  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
+
+  void *target_ptr;
+  aoti_torch_get_data_ptr(targets.get(), &target_ptr);
+
+  void *logit_len_ptr;
+  
aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *beta_ptr; + aoti_torch_get_data_ptr(betas, &beta_ptr); + + void *cost_ptr; + aoti_torch_get_data_ptr(costs, &cost_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)cost_ptr, + /*betas=*/(float*)beta_ptr); + return RAIIATH(betas); +} + + +void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_betas", &compute_betas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_betas", &boxed_compute_betas); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 43dae68027..1073b18a81 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,151 +1,216 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const 
torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), - "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - 
targets.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and targets");
-
-  TORCH_CHECK(
-      blank >= 0 && blank < logits.size(-1),
-      "blank must be within [0, logits.shape[-1])");
-
-  TORCH_CHECK(
-      logits.size(1) == at::max(logit_lengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
-      "output length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(target_lengths).item().toInt(),
-      "target length mismatch");
-
-  Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
-  options.blank_ = blank;
-  options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
-
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
-  options.stream_ = at::cuda::getCurrentCUDAStream();
-  cudaSetDevice(logits.get_device());
+
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  int32_t logit_lengths_device;
+  aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device);
+  int32_t target_lengths_device;
+  aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device);
+
+  AOTI_TORCH_CHECK(logits_device == targets_device);
+  AOTI_TORCH_CHECK(logits_device == logit_lengths_device);
+  AOTI_TORCH_CHECK(logits_device == target_lengths_device);
+
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+  AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() ||
+                   logits_dtype == aoti_torch_dtype_float16());
+
+  // targets must be int32, matching the original TORCH_CHECK validation
+  // (no float16 fallback).
+  int32_t targets_dtype;
+  aoti_torch_get_dtype(targets.get(), &targets_dtype);
+  AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32());
+
+  int32_t 
logit_lengths_dtype; + aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + AOTI_TORCH_CHECK( + logit_lengths_dtype == aoti_torch_dtype_int32()); + + int32_t target_lengths_dtype; + aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + AOTI_TORCH_CHECK( + target_lengths_dtype == aoti_torch_dtype_int32()); + + bool bool_tmp; + aoti_torch_is_contiguous(logits.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(targets.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + int64_t int_tmp; + aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + + // TORCH_CHECK( + // blank >= 0 && blank < logits.size(-1), + // "blank must be within [0, logits.shape[-1])"); + + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); 
+ // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); + + Options options; + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)target_lengths_size; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, &int_tmp); + options.maxTgtLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_softmax; + + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + + AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cuda()); + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device_index); options.device_ = GPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle float_workspace; sizes[0] = DtypeWorkspace::ComputeSizeFromOptions(options); + aoti_torch_empty_strided(1, sizes, 
strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + 
/*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - }; - return std::make_tuple(costs, gradients); + return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); +} + +void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from((std::get<0>(result)).release()); + stack[1] = from((std::get<1>(result)).release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss", &compute); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("torchaudio::rnnt_loss", &boxed_compute); } } // namespace gpu diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu index bde40daa9f..90e421ab4a 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu @@ -1,71 +1,125 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - 
const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH compute_alphas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; options.blank_ = blank; options.clamp_ = clamp; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device_index); options.device_ = GPU; - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t param_sizes[3] = 
{options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle alphas; + aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); + aoti_torch_zero_(alphas); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + sizes[0] = DtypeWorkspace::ComputeSizeFromOptions(options); + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void 
*logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *alpha_ptr; + aoti_torch_get_data_ptr(alphas, &alpha_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*alphas=*/(float*)alpha_ptr); + return RAIIATH(alphas); +} + +void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_alphas", &boxed_compute_alphas); } } // namespace gpu diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu index 18857c4388..7bed017b14 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_betas.cu @@ -1,76 +1,133 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = 
torch::aot_inductor::RAIIAtenTensorHandle; + + +RAIIATH compute_betas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); + Options options; + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; + options.blank_ = blank; + options.clamp_ = clamp; + + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); + + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device_index); options.device_ = GPU; - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = 
{1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; + AtenTensorHandle betas; + aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle int_workspace; + int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle float_workspace; w_sizes[0] = DtypeWorkspace::ComputeSizeFromOptions(options); + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + 
/*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *beta_ptr; + aoti_torch_get_data_ptr(betas, &beta_ptr); + + void *cost_ptr; + aoti_torch_get_data_ptr(costs, &cost_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)cost_ptr, + /*betas=*/(float*)beta_ptr); + return RAIIATH(betas); +} + +void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_betas", &compute_betas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_betas", &boxed_compute_betas); } } // namespace gpu diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 42dde06814..b278d96bd4 100644 --- a/src/torchaudio/functional/functional.py +++ 
b/src/torchaudio/functional/functional.py @@ -1760,6 +1760,19 @@ def _fix_waveform_shape( waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:]) return waveform_shift +class RnntLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + output, saved = torch.ops.torchaudio.rnnt_loss.default(*args) + ctx.save_for_backward(saved) + return output + + @staticmethod + def backward(ctx, dy): + grad = ctx.saved_tensors[0] + grad_out = dy.view((-1, 1, 1, 1)) + result = grad * grad_out; + return (result, None, None, None, None, None, None, None) def _rnnt_loss( logits: Tensor, @@ -1803,14 +1816,14 @@ def _rnnt_loss( if blank < 0: # reinterpret blank index if blank < 0. blank = logits.shape[-1] + blank - costs, _ = torch.ops.torchaudio.rnnt_loss( - logits=logits, - targets=targets, - logit_lengths=logit_lengths, - target_lengths=target_lengths, - blank=blank, - clamp=clamp, - fused_log_softmax=fused_log_softmax, + costs = RnntLoss.apply( + logits, + targets, + logit_lengths, + target_lengths, + blank, + clamp, + fused_log_softmax ) if reduction == "mean":