Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 92 additions & 62 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

using namespace std;

namespace torchaudio {
namespace alignment {
namespace cpu {

using torch::stable::Tensor;
using torch::headeronly::ScalarType;

// Inspired from
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
template <typename scalar_t, at::ScalarType target_scalar_type>
template <typename scalar_t, ScalarType target_scalar_type>
void forced_align_impl(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const Tensor& logProbs,
const Tensor& targets,
const int64_t blank,
torch::Tensor& paths) {
Tensor& paths) {
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
const auto batchIndex =
0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
Expand Down Expand Up @@ -132,79 +130,111 @@ void forced_align_impl(
delete[] backPtr_a;
}

std::tuple<torch::Tensor, torch::Tensor> compute(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
std::tuple<Tensor, Tensor> compute(
const Tensor& logProbs,
const Tensor& targets,
const Tensor& inputLengths,
const Tensor& targetLengths,
const int64_t blank) {
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
TORCH_CHECK(
logProbs.device() == targets.device(),
"log_probs and targets need to be on the same device");
TORCH_CHECK(
logProbs.dtype() == torch::kFloat64 ||
logProbs.dtype() == torch::kFloat32 ||
logProbs.dtype() == torch::kFloat16,
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
STD_TORCH_CHECK(
logProbs.scalar_type() == ScalarType::Double ||
logProbs.scalar_type() == ScalarType::Float ||
logProbs.scalar_type() == ScalarType::Half,
"log_probs must be float64, float32 or float16 (half) type");
TORCH_CHECK(
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
STD_TORCH_CHECK(
targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long,
"targets must be int32 or int64 type");
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
STD_TORCH_CHECK(
logProbs.dim() == 3,
"log_probs must be 3-D (batch_size, input length, num classes)");
TORCH_CHECK(
STD_TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
TORCH_CHECK(
STD_TORCH_CHECK(
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
STD_TORCH_CHECK(
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
STD_TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
STD_TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
TORCH_CHECK(
STD_TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");

TORCH_CHECK(
logProbs.size(1) == at::max(inputLengths).item().toInt(),
STD_TORCH_CHECK(
logProbs.size(1) == torchaudio::util::max<int>(inputLengths),
"input length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(targetLengths).item().toInt(),
STD_TORCH_CHECK(
targets.size(1) == torchaudio::util::max<int>(targetLengths),
"target length mismatch");

const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
auto paths = torch::zeros(
{B, T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
forced_align_impl<scalar_t, torch::kInt64>(
logProbs, targets, blank, paths);
} else {
forced_align_impl<scalar_t, torch::kInt32>(
logProbs, targets, blank, paths);
}
});
return std::make_tuple(
paths,
logProbs
);
}

Tensor paths = torch::stable::new_empty(targets, {B, T});
torch::stable::zero_(paths);

switch (logProbs.scalar_type()) {
case ScalarType::Double: {
if (targets.scalar_type() == ScalarType::Long) {
forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
} else if (targets.scalar_type() == ScalarType::Int) {
forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
} else {
STD_TORCH_CHECK(false, "unreachable");
}
break;
}
case ScalarType::Float: {
if (targets.scalar_type() == ScalarType::Long) {
forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
} else if (targets.scalar_type() == ScalarType::Int) {
forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
} else {
STD_TORCH_CHECK(false, "unreachable");
}
break;
}
case ScalarType::Half: {
if (targets.scalar_type() == ScalarType::Long) {
forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
} else if (targets.scalar_type() == ScalarType::Int) {
forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
} else {
STD_TORCH_CHECK(false, "unreachable");
}
break;
}
default: {
STD_TORCH_CHECK(false, "unreachable");
}
};

return std::make_tuple(paths, logProbs);
}

// Boxed-calling-convention wrapper around compute() for the stable
// dispatcher: unpacks 5 arguments from the StableIValue stack and writes the
// two results back into stack[0]/stack[1].
void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  // Fix: pass `blank` straight through as int64_t. The previous
  // float(to<int64_t>(...)) round-tripped the index through float, which is
  // lossy for large values and the wrong type for compute().
  std::tuple<Tensor, Tensor> res = compute(
      /*logProbs*/ to<Tensor>(stack[0]),
      /*targets*/ to<Tensor>(stack[1]),
      /*inputLengths*/ to<Tensor>(stack[2]),
      /*targetLengths*/ to<Tensor>(stack[3]),
      /*blank*/ to<int64_t>(stack[4]));
  stack[0] = from(std::get<0>(res));
  stack[1] = from(std::get<1>(res));
}

// Register the boxed CPU kernel with the stable-ABI dispatcher (replaces the
// old TORCH_LIBRARY_IMPL registration of compute()).
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("forced_align", &boxed_forced_align_cpu);
}

} // namespace cpu
Expand Down
Loading
Loading