This is a simplified version of https://github.com/parlance/ctcdecode.
For details on the original ctcdecode, see https://github.com/parlance/ctcdecode.
The main differences are:
- Remove KenLM support
- Remove dependencies
  - Boost
  - utf8
  - ThreadPool
- Clean up the library
  - Use TorchScript to bind the C++ code
  - Use Torch's `at::parallel_for` in place of `ThreadPool`
  - Remove unused functions
- Rename the module and decoder class (`simple_ctc.BeamSearchDecoder`)
- Move the original `decode` method to `forward`, and replace `decode` with a high-level API that also performs label conversion
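
To illustrate what "label conversion" means in the last item, here is a minimal sketch in plain Python. It assumes the low-level result is a nested list of integer label indices shaped `[batch][beam][token]`; that shape, the vocabulary, and the `to_text` helper are hypothetical and not part of this library.

```python
# Hypothetical vocabulary: index 0 is the blank symbol (assumption).
labels = ["_", "a", "b", "c"]


def to_text(indices, labels, blank_id=0):
    """Map a sequence of label indices to a string, skipping the blank."""
    return "".join(labels[i] for i in indices if i != blank_id)


# Assumed shape of an index-level result: [batch][beam][token].
index_result = [[[3, 1, 2], [3, 1]]]

# Convert every beam of every batch item to text.
text_result = [[to_text(seq, labels) for seq in beams] for beams in index_result]
print(text_result)
```

A high-level `decode` call would save the caller from writing this mapping by hand.
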

TODO:
- Add language model back.
- Fix timestep bug.

Dependencies:
- OpenFST (statically built/linked when installing)
- PyTorch

See requirements.txt for the Python package requirements.

NOTE: The build process downloads and compiles OpenFST, so it takes a while.

    pip install git+https://github.com/mthrok/ctcdecode

For development:

    git clone https://github.com/mthrok/ctcdecode
    cd ctcdecode
    python setup.py develop

NOTE: Timesteps are currently not computed correctly, so an empty list is returned for them.

    from simple_ctc import BeamSearchDecoder

    decoder = BeamSearchDecoder(
        labels,
        beam_size=100,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        blank_id=0,
        is_nll=False,
        num_processes=4,
    )
    result = decoder.decode(prob_seqs, seq_lens)
    print(result.labels[batch][beam][:])     # Resulting label sequences. 3D list.
    print(result.scores[batch][beam])        # Scores of the sequences. 2D list.
    print(result.timesteps[batch][beam][:])  # Timestep of each label's peak probability. 3D list (currently empty, see NOTE).

This decoder supports TorchScript. You should be able to deploy the dumped object in a non-Python environment by loading libctcdecode.so in your application.

    import torch

    path = 'decoder.zip'
    torch.jit.save(decoder, path)    # Dump the decoder object.
    decoder = torch.jit.load(path)   # Load it back.
    result = decoder.decode(prob_seqs, seq_lens)
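
For readers unfamiliar with the constructor arguments in the decoding example above, the following is a minimal pure-Python sketch of CTC prefix beam search. It is not this library's implementation. It only illustrates the roles that `beam_size`, `cutoff_top_n`, `cutoff_prob`, and `blank_id` conventionally play. The function name `prefix_beam_search`, the input layout (a list of per-frame probability lists), and the exact pruning rule are assumptions made for illustration.

```python
from collections import defaultdict


def prefix_beam_search(probs, beam_size=100, cutoff_top_n=40,
                       cutoff_prob=1.0, blank_id=0):
    """Illustrative CTC prefix beam search (a sketch, not the library code).

    probs: list of frames; each frame is a list of label probabilities.
    Returns a list of (label_indices, probability), best first.
    """
    # prefix -> (prob. of paths ending in blank, prob. ending in non-blank)
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        # Pruning (assumed rule): take labels in descending probability,
        # at most cutoff_top_n, stopping once their mass reaches cutoff_prob.
        ranked = sorted(range(len(frame)), key=lambda i: frame[i], reverse=True)
        kept, total = [], 0.0
        for i in ranked[:cutoff_top_n]:
            kept.append(i)
            total += frame[i]
            if total >= cutoff_prob:
                break

        next_beams = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for c in kept:
                p = frame[c]
                if c == blank_id:
                    # A blank keeps the prefix unchanged.
                    next_beams[prefix][0] += (p_b + p_nb) * p
                elif prefix and prefix[-1] == c:
                    # Repeated label without a blank collapses into the prefix.
                    next_beams[prefix][1] += p_nb * p
                    # Repeated label after a blank extends the prefix.
                    next_beams[prefix + (c,)][1] += p_b * p
                else:
                    next_beams[prefix + (c,)][1] += (p_b + p_nb) * p

        # Keep only the beam_size most probable prefixes.
        best = sorted(next_beams.items(),
                      key=lambda kv: kv[1][0] + kv[1][1], reverse=True)
        beams = {k: tuple(v) for k, v in best[:beam_size]}

    results = [(list(k), p_b + p_nb) for k, (p_b, p_nb) in beams.items()]
    return sorted(results, key=lambda x: x[1], reverse=True)


# Two frames over the labels [blank, 'a', 'b'].
frames = [[0.6, 0.4, 0.0], [0.6, 0.4, 0.0]]
top = prefix_beam_search(frames)
print(top[0])
```

In this toy input the blank is the most likely label in every frame, yet the paths that collapse to a single `a` together outweigh the all-blank path, so the search ranks `[1]` first. Tightening the pruning (for example `cutoff_top_n=1`) discards `a` before it can be considered, which is the usual speed versus accuracy trade-off of these parameters.
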