From cec3363bf1739a2ceff72256890770c66e2e3db3 Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 7 Apr 2025 14:18:55 +0200 Subject: [PATCH 01/11] Add TreeTimesyncBeamSearch --- .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + src/Search/Makefile | 4 + src/Search/Module.cc | 5 + src/Search/Module.hh | 3 +- src/Search/TreeBuilder.cc | 2 +- src/Search/TreeTimesyncBeamSearch/Makefile | 24 + .../TreeTimesyncBeamSearch.cc | 708 ++++++++++++++++++ .../TreeTimesyncBeamSearch.hh | 215 ++++++ 12 files changed, 964 insertions(+), 2 deletions(-) create mode 100644 src/Search/TreeTimesyncBeamSearch/Makefile create mode 100644 src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc create mode 100644 src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index dbd334062..7884cc931 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) +LIBS_SEARCH += src/Search/TreeTimesyncBeamSearch/libSprintTreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index dbd334062..7884cc931 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) +LIBS_SEARCH += src/Search/TreeTimesyncBeamSearch/libSprintTreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index e8db4f150..cc9a21817 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) +LIBS_SEARCH += src/Search/TreeTimesyncBeamSearch/libSprintTreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index 906530d1d..34d8dc50a 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -147,6 +147,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) LIBS_SEARCH += 
src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) +LIBS_SEARCH += src/Search/TreeTimesyncBeamSearch/libSprintTreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index b4b900e0b..b78f99efa 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -148,6 +148,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) +LIBS_SEARCH += src/Search/TreeTimesyncBeamSearch/libSprintTreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/src/Search/Makefile b/src/Search/Makefile index 69fcc1a74..48f84be78 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -37,6 +37,7 @@ LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif SUBDIRS += LexiconfreeTimesyncBeamSearch +SUBDIRS += TreeTimesyncBeamSearch ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst endif @@ -70,6 +71,9 @@ AdvancedTreeSearch: LexiconfreeTimesyncBeamSearch: $(MAKE) -C $@ libSprintLexiconfreeTimesyncBeamSearch.$(a) +TreeTimesyncBeamSearch: + $(MAKE) -C $@ libSprintTreeTimesyncBeamSearch.$(a) + include $(TOPDIR)/Rules.make sinclude $(LIBSPRINTSEARCH_O:.o=.d) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 7910f5950..ad2553525 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -17,6 +17,7 @@ #include #include #include "LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh" +#include "TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh" #include "TreeBuilder.hh" #ifdef MODULE_SEARCH_WFST #include @@ -36,6 +37,7 @@ Module_::Module_() { const Core::Choice Module_::searchTypeV2Choice( "lexiconfree-timesync-beam-search", SearchTypeV2::LexiconfreeTimesyncBeamSearchType, + "tree-timesync-beam-search", SearchTypeV2::TreeTimesyncBeamSearchType, Core::Choice::endMark()); const Core::ParameterChoice Module_::searchTypeV2Param( @@ -115,6 +117,9 @@ SearchAlgorithmV2* Module_::createSearchAlgorithmV2(const Core::Configuration& c case LexiconfreeTimesyncBeamSearchType: searchAlgorithm = new Search::LexiconfreeTimesyncBeamSearch(config); break; + case TreeTimesyncBeamSearchType: + searchAlgorithm = new Search::TreeTimesyncBeamSearch(config); + break; default: Core::Application::us()->criticalError("Unknown search algorithm type: %d", searchTypeV2Param(config)); break; diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 72256811f..1d8930366 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -42,7 +42,8 @@ enum SearchType { }; enum SearchTypeV2 { - LexiconfreeTimesyncBeamSearchType + LexiconfreeTimesyncBeamSearchType, + TreeTimesyncBeamSearchType }; class Module_ { diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index ebda597c2..62175bccf 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1369,7 +1369,7 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia // Add new (non-blank) state currentState = extendState(currentState, desc); - if (labelLoop_) { 
+ if (labelLoop_ and not allophoneIsBlank) { // Add loop for this state addTransition(currentState, currentState); } diff --git a/src/Search/TreeTimesyncBeamSearch/Makefile b/src/Search/TreeTimesyncBeamSearch/Makefile new file mode 100644 index 000000000..9a7fad358 --- /dev/null +++ b/src/Search/TreeTimesyncBeamSearch/Makefile @@ -0,0 +1,24 @@ +#!gmake + +TOPDIR = ../../.. + +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintTreeTimesyncBeamSearch.$(a) + +LIBSPRINTTREETIMESYNCBEAMSEARCH_O = $(OBJDIR)/TreeTimesyncBeamSearch.o + + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintTreeTimesyncBeamSearch.$(a): $(LIBSPRINTTREETIMESYNCBEAMSEARCH_O) + $(MAKELIB) $@ $^ + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTTREETIMESYNCBEAMSEARCH_O:.o=.d) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc new file mode 100644 index 000000000..10cdd589e --- /dev/null +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -0,0 +1,708 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TreeTimesyncBeamSearch.hh" + +#include +#include + +#include +#include +#include +#include +#include "Search/Traceback.hh" +#include "Search/Module.hh" + +namespace Search { + +/* + * ======================= + * === LabelHypothesis === + * ======================= + */ + +TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() + : scoringContext(), + currentToken(Core::Type::max), + currentState(invalidTreeNodeIndex), + lmHistory(), + score(0.0), + trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} + +TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( + TreeTimesyncBeamSearch::LabelHypothesis const& base, + TreeTimesyncBeamSearch::ExtensionCandidate const& extension, + Nn::ScoringContextRef const& newScoringContext) + : scoringContext(newScoringContext), + currentToken(extension.nextToken), + currentState(extension.state), + lmHistory(extension.lmHistory), + score(extension.score), + trace() { + + switch (extension.transitionType) { + case Nn::LabelScorer::INITIAL_BLANK: + case Nn::LabelScorer::INITIAL_LABEL: + trace = Core::ref(new LatticeTrace( + base.trace, + extension.pron, + extension.timeframe + 1, + {extension.score - extension.lmScore, extension.lmScore}, + {})); + break; + + case Nn::LabelScorer::LABEL_TO_LABEL: + case Nn::LabelScorer::BLANK_TO_LABEL: + case Nn::LabelScorer::LABEL_TO_BLANK: + if (base.trace->pronunciation != nullptr) { // A word has ended before and the first token of a new word was predicted -> start a new trace + trace = Core::ref(new LatticeTrace( + base.trace, + extension.pron, + extension.timeframe + 1, + {base.trace->score.acoustic + (extension.score - base.score - extension.lmScore), base.trace->score.lm + extension.lmScore}, + {})); + } + else { // Word-end or within-word hypothesis and no word has ended before -> update the old trace + trace = Core::ref(new LatticeTrace(*base.trace)); + trace->sibling = {}; + trace->pronunciation = extension.pron; + trace->time = extension.timeframe + 1; + trace->score.acoustic = base.trace->score.acoustic + (extension.score - base.score - extension.lmScore); + trace->score.lm = base.trace->score.lm + extension.lmScore; + } + break; + + case Nn::LabelScorer::LABEL_LOOP: + case Nn::LabelScorer::BLANK_LOOP: + // Word-end or within-word hypothesis (cannot happen across words) -> update the old trace + trace = Core::ref(new LatticeTrace(*base.trace)); + trace->sibling = {}; + trace->pronunciation = extension.pron; + trace->time = extension.timeframe + 1; + trace->score.acoustic = base.trace->score.acoustic + (extension.score - base.score - extension.lmScore); + trace->score.lm = base.trace->score.lm + extension.lmScore; + break; + } +} + +std::string TreeTimesyncBeamSearch::LabelHypothesis::toString() const { + std::stringstream ss; + ss << "Score: " << score << ", current state: " << currentState << ", traceback: "; + + auto traceback = trace->performTraceback(); + + for (auto& item : *traceback) { + if (item.pronunciation and item.pronunciation->lemma()) { + ss << item.pronunciation->lemma()->symbol() << " "; + } + } + return ss.str(); +} + +/* + * ===================================== + * === TreeTimesyncBeamSearch === + * ===================================== + */ + +const Core::ParameterInt TreeTimesyncBeamSearch::paramMaxBeamSize( + "max-beam-size", + "Maximum number of within-word hypotheses in the search beam.", + 1, 1); + +const Core::ParameterInt TreeTimesyncBeamSearch::paramMaxWordEndBeamSize( + "max-word-end-beam-size", + "Maximum number of word-end hypotheses in the search beam. 
If not set, global beam pruning will be done and word-end hypotheses will not be pruned separately.", + std::numeric_limits::max(), 0); + +const Core::ParameterFloat TreeTimesyncBeamSearch::paramScoreThreshold( + "score-threshold", + "Prune any within-word hypothesis with a score that is at least this much worse than the best hypothesis.", + Core::Type::max, 0); + +const Core::ParameterFloat TreeTimesyncBeamSearch::paramWordEndScoreThreshold( + "word-end-score-threshold", + "Prune any word-end hypothesis with a score that is at least this much worse than the best word-end hypothesis. If not set, global score pruning will be done \ + and word-end hypotheses will not be pruned separately. If the value is below 1.0, e.g. 0.7, then it is relative to within-word score-pruning.", + Core::Type::max, 0); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramCollapseRepeatedLabels( + "collapse-repeated-labels", + "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramForceBlankAcrossWords( + "force-blank-between-repeated-labels-across-words", + "Require a blank label between identical labels at word end and word begin.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack( + "sentence-end-fall-back", + "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", + true); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramLogStepwiseStatistics( + "log-stepwise-statistics", + "Log statistics about the beam at every search step.", + false); + +TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config) + : Core::Component(config), + SearchAlgorithmV2(config), + maxBeamSize_(paramMaxBeamSize(config)), + maxWordEndBeamSize_(paramMaxWordEndBeamSize(config)), + scoreThreshold_(paramScoreThreshold(config)), + wordEndScoreThreshold_(paramWordEndScoreThreshold(config)), + collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), + forceBlankAcrossWords_(paramForceBlankAcrossWords(config)), + sentenceEndFallback_(paramSentenceEndFallBack(config)), + logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + debugChannel_(config, "debug"), + labelScorer_(), + beam_(), + extensions_(), + newBeam_(), + requests_(), + recombinedHypotheses_(), + initializationTime_(), + featureProcessingTime_(), + scoringTime_(), + contextExtensionTime_(), + numHypsAfterScorePruning_("num-hyps-after-score-pruning"), + numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"), + numWordEndHypsAfterScorePruning_("num-word-end-hyps-after-score-pruning"), + numWordEndHypsAfterBeamPruning_("num-word-end-hyps-after-beam-pruning"), + numActiveHyps_("num-active-hyps"), + finishedSegment_(false) { + if (wordEndScoreThreshold_ <= 1.0) { + if (scoreThreshold_ == Core::Type::max) { + error() << "Word-end score-threshold relative to score-threshold, but score-threshold is not set"; + } + wordEndScoreThreshold_ *= scoreThreshold_; + } +} + +Speech::ModelCombination::Mode TreeTimesyncBeamSearch::requiredModelCombination() const { + return Speech::ModelCombination::useLabelScorer | Speech::ModelCombination::useLexicon | Speech::ModelCombination::useAcousticModel | Speech::ModelCombination::useLanguageModel; +} + +Speech::ModelCombination::Mode TreeTimesyncBeamSearch::requiredAcousticModel() const { + return Am::AcousticModel::noEmissions; +} + +bool 
TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& modelCombination) { + lexicon_ = modelCombination.lexicon(); + labelScorer_ = modelCombination.labelScorer(); + acousticModel_ = modelCombination.acousticModel(); + languageModel_ = modelCombination.languageModel(); + + blankLabelIndex_ = acousticModel_->emissionIndex(acousticModel_->blankAllophoneStateIndex()); + + // Build the search tree + log() << "Start building search tree"; + network_ = Core::ref(new PersistentStateTree(config, acousticModel_, lexicon_, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5))); + std::unique_ptr builder = Search::Module::instance().createTreeBuilder(config, *lexicon_, *acousticModel_, *network_); + builder->build(); + log() << "Building finished"; + + // Create look-ups for state successors and exits of each state + createSuccessorLookups(); + + // Pre-allocate vectors + + // If maxWordEndBeamSize_ is not set, we need the maximum number of exits a node can have for estimating the max. size of the vectors + int maxWordEnds = maxWordEndBeamSize_ == std::numeric_limits::max() ? maxNumberOfExits_ : maxWordEndBeamSize_; + + // The beam contains all within-word and word-end hypotheses which survived pruning + beam_.reserve(maxBeamSize_ + maxWordEnds); + newBeam_.reserve(maxBeamSize_ + maxWordEnds); + recombinedHypotheses_.reserve(maxBeamSize_ + maxWordEnds); + + // Each hypothesis in the beam can yield max. one extension per phoneme in the lexicon + extensions_.reserve((maxBeamSize_ + maxWordEnds) * lexicon_->phonemeInventory()->nPhonemes()); + requests_.reserve((maxBeamSize_ + maxWordEnds) * lexicon_->phonemeInventory()->nPhonemes()); + + // After pruning there are maxBeamSize_ state extensions, each can yield max. 
maxNumberOfExits_ word-end extensions + withinWordExtensions_.reserve(maxBeamSize_); + wordEndExtensions_.reserve(maxBeamSize_ * maxNumberOfExits_); + + reset(); + return true; +} + +void TreeTimesyncBeamSearch::reset() { + initializationTime_.start(); + + labelScorer_->reset(); + + // Reset beam to a single empty hypothesis + beam_.clear(); + beam_.push_back(LabelHypothesis()); + beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); + beam_.front().currentState = network_->rootState; + beam_.front().lmHistory = languageModel_->startHistory(); + + finishedSegment_ = false; + + initializationTime_.stop(); +} + +void TreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) { + initializationTime_.start(); + labelScorer_->reset(); + resetStatistics(); + initializationTime_.stop(); + finishedSegment_ = false; +} + +void TreeTimesyncBeamSearch::finishSegment() { + featureProcessingTime_.start(); + labelScorer_->signalNoMoreFeatures(); + featureProcessingTime_.stop(); + decodeManySteps(); + logStatistics(); + finishedSegment_ = true; + finalizeLmScoring(); +} + +void TreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { + featureProcessingTime_.start(); + labelScorer_->addInput(data, featureSize); + featureProcessingTime_.stop(); +} + +void TreeTimesyncBeamSearch::putFeature(std::vector const& data) { + featureProcessingTime_.start(); + labelScorer_->addInput(data); + featureProcessingTime_.stop(); +} + +void TreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { + featureProcessingTime_.start(); + labelScorer_->addInputs(data, timeSize, featureSize); + featureProcessingTime_.stop(); +} + +Core::Ref TreeTimesyncBeamSearch::getCurrentBestTraceback() const { + return getBestHypothesis().trace->performTraceback(); +} + +Core::Ref TreeTimesyncBeamSearch::getCurrentBestWordLattice() const { + auto& bestHypothesis = getBestHypothesis(); + LatticeTrace endTrace(bestHypothesis.trace, 0, bestHypothesis.trace->time + 1, bestHypothesis.trace->score, {}); + + for (size_t hypIdx = 1ul; hypIdx < beam_.size(); ++hypIdx) { + auto& hyp = beam_[hypIdx]; + auto siblingTrace = Core::ref(new LatticeTrace(hyp.trace, 0, hyp.trace->time, hyp.trace->score, {})); + endTrace.appendSiblingToChain(siblingTrace); + } + + return endTrace.buildWordLattice(lexicon_); +} + +bool TreeTimesyncBeamSearch::decodeStep() { + if (finishedSegment_) { + return false; + } + + /* + * Collect all possible within-word extensions for all hypotheses in the beam. + * Also create scoring requests for the label scorer. + * Each extension candidate makes up a request. 
+ */ + extensions_.clear(); + requests_.clear(); + + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + + // Iterate over the successors of this hypothesis' current state in the tree + for (const auto& successorState: stateSuccessorLookup_[hyp.currentState]) { + Nn::LabelIndex tokenIdx = network_->structure.state(successorState).stateDesc.acousticModel; + // If we want to force blank between repeated labels across words, a new word should not start with the same token as the previous word ended (except for blank itself) + // If we don't force blank and we have a repeated label across words, we need to make sure to have label-to-Label as transition type + if (not (forceBlankAcrossWords_ and (hyp.currentState == network_->rootState) and (tokenIdx == hyp.currentToken) and (tokenIdx != blankLabelIndex_))) { + auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx, hyp.currentState == network_->rootState); + extensions_.push_back( + {tokenIdx, + nullptr, + successorState, + hyp.lmHistory, + hyp.score, + 0.0, + 0, + transitionType, + hypIndex}); + requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); + } + } + } + + /* + * Perform scoring of all the requests with the label scorer. + */ + scoringTime_.start(); + auto result = labelScorer_->computeScoresWithTimes(requests_); + scoringTime_.stop(); + + if (not result) { + // LabelScorer could not compute scores -> no search step can be made. + return false; + } + + for (size_t requestIdx = 0ul; requestIdx < extensions_.size(); ++requestIdx) { + extensions_[requestIdx].score += result->scores[requestIdx]; + extensions_[requestIdx].timeframe = result->timeframes[requestIdx]; + } + + if (logStepwiseStatistics_) { + clog() << Core::XmlOpen("search-step-stats"); + } + + + /* + * Prune set of possible within-word extensions by max beam size and possibly also by score. 
+ */ + scorePruning(extensions_, scoreThreshold_); + numHypsAfterScorePruning_ += extensions_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-score-pruning", extensions_.size()); + } + + beamSizePruning(extensions_, maxBeamSize_); + numHypsAfterBeamPruning_ += extensions_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-beam-pruning", extensions_.size()); + } + + /* + * Expand extensions to word-end hypotheses and incorporate the language model + */ + withinWordExtensions_.clear(); + wordEndExtensions_.clear(); + for (const auto& extension: extensions_) { + // If there is at least one state successor, keep it as within-word hypothesis + if (not stateSuccessorLookup_[extension.state].empty()) { + withinWordExtensions_.push_back(extension); + } + std::vector exitList = exitLookup_[extension.state]; + if (not exitList.empty()) { + // Create one word-end hypothesis for each exit + for (const auto& exit: exitList) { + ExtensionCandidate wordEndExtension(extension); + const Bliss::LemmaPronunciation* lemmaPron = lexicon_->lemmaPronunciation(exit.pronunciation); + const Bliss::Lemma* lemma = lemmaPron->lemma(); + + // Start from the root node (the exit's transit state) in the next step + wordEndExtension.state = exit.transitState; + wordEndExtension.pron = lemmaPron; + + if (lemma != lexicon_->specialLemma("blank")) { + const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence(); + const Bliss::SyntacticToken* st = sts.front(); + + // Add the LM score and update the LM history + Lm::Score lmScore = languageModel_->score(wordEndExtension.lmHistory, st); + wordEndExtension.score += lmScore; + wordEndExtension.lmScore = lmScore; + wordEndExtension.lmHistory = languageModel_->extendedHistory(wordEndExtension.lmHistory, st); + } + wordEndExtensions_.push_back(wordEndExtension); + } + } + } + + /* + * Prune set of word-end hypotheses by max beam size and possibly also by score. + */ + scorePruning(wordEndExtensions_, wordEndScoreThreshold_); + numWordEndHypsAfterScorePruning_ += wordEndExtensions_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", wordEndExtensions_.size()); + } + + beamSizePruning(wordEndExtensions_, maxWordEndBeamSize_); + numWordEndHypsAfterBeamPruning_ += wordEndExtensions_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndExtensions_.size()); + } + + /* + * Create new beam from surviving extensions. + */ + newBeam_.clear(); + extensions_.swap(withinWordExtensions_); + extensions_.insert(extensions_.end(), wordEndExtensions_.begin(), wordEndExtensions_.end()); + + for (auto const& extension: extensions_) { + auto const& baseHyp = beam_[extension.baseHypIndex]; + + auto newScoringContext = labelScorer_->extendedScoringContext( + {baseHyp.scoringContext, + extension.nextToken, + extension.transitionType}); + + newBeam_.push_back({baseHyp, extension, newScoringContext}); + } + + /* + * For all hypotheses at the same state and with the same scoring context and LM history + * keep only the best since they will all develop in the same way. 
+ */ + recombination(newBeam_); + numActiveHyps_ += newBeam_.size(); + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("active-hyps", newBeam_.size()); + } + + if (debugChannel_.isOpen()) { + std::stringstream ss; + for (size_t hypIdx = 0ul; hypIdx < newBeam_.size(); ++hypIdx) { + ss << "Hypothesis " << hypIdx + 1ul << ": " << newBeam_[hypIdx].toString() << "\n"; + } + ss << "\n"; + debugChannel_ << ss.str(); + } + + beam_.swap(newBeam_); + + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("best-hyp-score", getBestHypothesis().score); + clog() << Core::XmlFull("worst-hyp-score", getWorstHypothesis().score); + clog() << Core::XmlClose("search-step-stats"); + } + + return true; +} + +TreeTimesyncBeamSearch::LabelHypothesis const& TreeTimesyncBeamSearch::getBestHypothesis() const { + verify(not beam_.empty()); + + return *std::min_element(beam_.begin(), beam_.end()); +} + +TreeTimesyncBeamSearch::LabelHypothesis const& TreeTimesyncBeamSearch::getWorstHypothesis() const { + verify(not beam_.empty()); + + return *std::max_element(beam_.begin(), beam_.end()); +} + +void TreeTimesyncBeamSearch::resetStatistics() { + initializationTime_.reset(); + featureProcessingTime_.reset(); + scoringTime_.reset(); + contextExtensionTime_.reset(); + numHypsAfterScorePruning_.clear(); + numHypsAfterBeamPruning_.clear(); + numWordEndHypsAfterScorePruning_.clear(); + numWordEndHypsAfterBeamPruning_.clear(); + numActiveHyps_.clear(); +} + +void TreeTimesyncBeamSearch::logStatistics() const { + clog() << Core::XmlOpen("timing-statistics") + Core::XmlAttribute("unit", "milliseconds"); + clog() << Core::XmlOpen("initialization-time") << initializationTime_.elapsedMilliseconds() << Core::XmlClose("initialization-time"); + clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.elapsedMilliseconds() << Core::XmlClose("feature-processing-time"); + clog() << Core::XmlOpen("scoring-time") << scoringTime_.elapsedMilliseconds() << Core::XmlClose("scoring-time"); + clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time"); + clog() << Core::XmlClose("timing-statistics"); + numHypsAfterScorePruning_.write(clog()); + numHypsAfterBeamPruning_.write(clog()); + numWordEndHypsAfterScorePruning_.write(clog()); + numWordEndHypsAfterBeamPruning_.write(clog()); + numActiveHyps_.write(clog()); +} + +Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel, bool inRoot) const { + bool prevIsBlank = prevLabel == blankLabelIndex_; + bool nextIsBlank = nextLabel == blankLabelIndex_; + + if (prevLabel == Core::Type::max) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::INITIAL_BLANK; + } + else { + return Nn::LabelScorer::TransitionType::INITIAL_LABEL; + } + } + + if (prevIsBlank) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::BLANK_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL; + } + } + else { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::LABEL_TO_BLANK; + } + else if (collapseRepeatedLabels_ and prevLabel == nextLabel and not inRoot) { + return Nn::LabelScorer::TransitionType::LABEL_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; + } + } +} + +void TreeTimesyncBeamSearch::beamSizePruning(std::vector& extensions, size_t maxBeamSize) const { + if (extensions.size() <= maxBeamSize) { + return; + } + + // Sort the hypotheses by associated 
score value such that the first `maxBeamSize` elements are the best + std::nth_element(extensions.begin(), extensions.begin() + maxBeamSize, extensions.end()); + extensions.resize(maxBeamSize); // Get rid of excessive elements +} + +void TreeTimesyncBeamSearch::scorePruning(std::vector& extensions, Score scoreThreshold) const { + if (extensions.empty() or scoreThreshold == Core::Type::max) { + return; + } + + // Compute the pruning threshold + auto bestScore = std::min_element(extensions.begin(), extensions.end())->score; + auto pruningThreshold = bestScore + scoreThreshold; + + // Remove elements with score > pruningThreshold + extensions.erase( + std::remove_if( + extensions.begin(), + extensions.end(), + [=](auto const& ext) { return ext.score > pruningThreshold; }), + extensions.end()); +} + +void TreeTimesyncBeamSearch::recombination(std::vector& hypotheses) { + // Represents a unique combination of StateId, ScoringContext and LmHistory + struct RecombinationContext { + StateId state; + Nn::ScoringContextRef scoringContext; + Lm::History lmHistory; + + RecombinationContext(StateId state, Nn::ScoringContextRef scoringContext, Lm::History lmHistory) + : state(state), scoringContext(scoringContext), lmHistory(lmHistory) {} + + bool operator==(const RecombinationContext& other) const { + return state == other.state && Nn::ScoringContextEq{}(scoringContext, other.scoringContext) && lmHistory == other.lmHistory; + } + }; + struct RecombinationContextHash { + size_t operator()(const RecombinationContext& context) const { + size_t h1 = context.state; + size_t h2 = Nn::ScoringContextHash{}(context.scoringContext); + size_t h3 = Lm::History::Hash{}(context.lmHistory); + return h1 ^ (h2 << 1) ^ (h3 << 2); + } + }; + + recombinedHypotheses_.clear(); + // Map each unique combination of StateId, ScoringContext and LmHistory in newHypotheses to its hypothesis + std::unordered_map seenCombinations; + for (auto const& hyp : hypotheses) { + // Use try_emplace to check if the combination already exists and create a new entry if not at the same time + auto [it, inserted] = seenCombinations.try_emplace({hyp.currentState, hyp.scoringContext, hyp.lmHistory}, nullptr); + + if (inserted) { + // First time seeing this combination so move it over to `newHypotheses` + recombinedHypotheses_.push_back(std::move(hyp)); + it->second = &recombinedHypotheses_.back(); + } + else { + verify(not hyp.trace->sibling); + + auto* existingHyp = it->second; + if (hyp.score < existingHyp->score) { + // New hyp is better -> replace in `newHypotheses` and add existing one as sibling + hyp.trace->sibling = existingHyp->trace; + *existingHyp = std::move(hyp); // Overwrite in-place + } + else { + // New hyp is worse -> add to existing one as sibling + hyp.trace->sibling = existingHyp->trace->sibling; + existingHyp->trace->sibling = hyp.trace; + } + } + } + + hypotheses.swap(recombinedHypotheses_); +} + +void TreeTimesyncBeamSearch::createSuccessorLookups() { + stateSuccessorLookup_.reserve(network_->structure.stateCount()); + exitLookup_.reserve(network_->structure.stateCount()); + + for (u32 state = 1; state < network_->structure.stateCount(); ++state) { + std::vector stateList; // Collect the state successors of all nodes + std::vector exitList; // Collect the exits of all nodes + for (HMMStateNetwork::SuccessorIterator it = network_->structure.successors(state); it; ++it) { + if (not it.isLabel()) { + stateList.push_back(*it); + } else { + exitList.push_back(network_->exits[it.label()]); + } + } + stateSuccessorLookup_[state] = 
stateList; + exitLookup_[state] = exitList; + + // Retrieve the maximal number of exits a node in the tree can have to estimate the size of pre-allocated vectors + if (exitList.size() > maxNumberOfExits_) { + maxNumberOfExits_ = exitList.size(); + } + } +} + +void TreeTimesyncBeamSearch::finalizeLmScoring() { + newBeam_.clear(); + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + // Check if the hypotheses in the beam are at a root state and add the sentence-end LM score + if (hyp.currentState == network_->rootState or network_->otherRootStates.find(hyp.currentState) != network_->otherRootStates.end()) { + Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); + hyp.score += sentenceEndScore; + hyp.trace->score.lm += sentenceEndScore; + newBeam_.push_back(hyp); + } + } + + if (newBeam_.empty()) { // There was no word-end hypothesis in the beam + warning("No active word-end hypothesis at segment end."); + if (sentenceEndFallback_) { + log() << "Use sentence-end fallback"; + // The trace of the unfinished word keeps an empty pronunciation, only the LM score is added + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); + hyp.score += sentenceEndScore; + hyp.trace->score.lm += sentenceEndScore; + newBeam_.push_back(hyp); + } + } + else { + // Construct an empty hypothesis with a lattice containing only one empty pronunciation from start to end + newBeam_.push_back(LabelHypothesis()); + newBeam_.front().trace->time = beam_.front().trace->time; // Retrieve the timeframe from any hyp in the old beam + newBeam_.front().trace->pronunciation = nullptr; + newBeam_.front().trace->predecessor = Core::ref(new LatticeTrace(0, {0, 0}, {})); + } + } + beam_.swap(newBeam_); +} + +} // namespace Search diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh new file mode 100644 index 000000000..c2be05881 --- /dev/null +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -0,0 +1,215 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TREE_TIMESYNC_BEAM_SEARCH_HH +#define TREE_TIMESYNC_BEAM_SEARCH_HH + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Search { + +/* + * Simple time synchronous beam search algorithm on a search tree built by the CtcTreeBuilder or RnaTreeBuilder. + * At a word end, a language model score is added to the hypothesis score; + * if no language model should be used, the LM-scale has to be set to 0.0. + * Supports global or separate pruning of within-word and word-end hypotheses + * by max beam-size and by score difference to the best hypothesis. + * Uses a LabelScorer for context initialization/extension and scoring.
+ * + * The blank label index is retrieved from the lexicon to ensure consistency with the blank index used for the search tree. + * If the search tree contains label-loops, one will most likely want to set "collapse-repeated-labels" to true so + * the label loops are also considered when inferring the transition type for scoring. + * Similarly, if the search tree forces blank between two repeated labels (and if repeated labels are collapsed), + * blank should also be forced across words if the new word starts with the same label that the previous word ended with, + * so "force-blank-between-repeated-labels-across-words" has to be set to true in this case. + */ +class TreeTimesyncBeamSearch : public SearchAlgorithmV2 { +protected: + /* + * Possible extension for some label hypothesis in the beam + */ + struct ExtensionCandidate { + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end + StateId state; // State in the search tree of this extension + Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end + Score score; // Would-be total score of the full hypothesis after extension (incl. LM score) + Score lmScore; // Would-be LM score of a word-end hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam + + bool operator<(ExtensionCandidate const& other) { + return score < other.score; + } + }; + + /* + * Struct containing all information about a single hypothesis in the beam + */ + struct LabelHypothesis { + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + StateId currentState; // Current state in the search tree + Lm::History lmHistory; // Language model history + Score score; // Full score of the hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + + LabelHypothesis(); + LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); + + bool operator<(LabelHypothesis const& other) const { + return score < other.score; + } + + /* + * Get string representation for debugging + */ + std::string toString() const; + }; + +public: + static const Core::ParameterInt paramMaxBeamSize; + static const Core::ParameterInt paramMaxWordEndBeamSize; + static const Core::ParameterFloat paramScoreThreshold; + static const Core::ParameterFloat paramWordEndScoreThreshold; + static const Core::ParameterBool paramCollapseRepeatedLabels; + static const Core::ParameterBool paramForceBlankAcrossWords; + static const Core::ParameterBool paramSentenceEndFallBack; + static const Core::ParameterBool paramLogStepwiseStatistics; + + TreeTimesyncBeamSearch(Core::Configuration const&); + + // Inherited methods from `SearchAlgorithmV2` + + Speech::ModelCombination::Mode requiredModelCombination() const override; + Speech::ModelCombination::Mode requiredAcousticModel() const override; + bool setModelCombination(Speech::ModelCombination const& modelCombination) override; + void reset() override; + void enterSegment(Bliss::SpeechSegment const* = nullptr) override; + void finishSegment() override; + void
putFeature(std::shared_ptr const& data, size_t featureSize) override; + void putFeature(std::vector const& data) override; + void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) override; + Core::Ref getCurrentBestTraceback() const override; + Core::Ref getCurrentBestWordLattice() const override; + bool decodeStep() override; + +private: + size_t maxBeamSize_; + size_t maxWordEndBeamSize_; + + Score scoreThreshold_; + Score wordEndScoreThreshold_; + + Nn::LabelIndex blankLabelIndex_; + + bool collapseRepeatedLabels_; + bool forceBlankAcrossWords_; + + bool sentenceEndFallback_; + + bool logStepwiseStatistics_; + + Core::Channel debugChannel_; + + Core::Ref labelScorer_; + Bliss::LexiconRef lexicon_; + Core::Ref network_; + Core::Ref acousticModel_; + Core::Ref languageModel_; + std::vector beam_; + + // Pre-allocated intermediate vectors + std::vector extensions_; + std::vector withinWordExtensions_; + std::vector wordEndExtensions_; + std::vector newBeam_; + std::vector requests_; + std::vector recombinedHypotheses_; + + int maxNumberOfExits_; + + std::vector> stateSuccessorLookup_; + std::vector> exitLookup_; + + Core::StopWatch initializationTime_; + Core::StopWatch featureProcessingTime_; + Core::StopWatch scoringTime_; + Core::StopWatch contextExtensionTime_; + + Core::Statistics numHypsAfterScorePruning_; + Core::Statistics numHypsAfterBeamPruning_; + Core::Statistics numWordEndHypsAfterScorePruning_; + Core::Statistics numWordEndHypsAfterBeamPruning_; + Core::Statistics numActiveHyps_; + + bool finishedSegment_; + + LabelHypothesis const& getBestHypothesis() const; + LabelHypothesis const& getWorstHypothesis() const; + + void resetStatistics(); + void logStatistics() const; + + /* + * Infer type of transition between two tokens based on whether each of them is blank + * and/or whether they are the same + */ + Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel, bool inRoot=false) const; + + /* + * Helper function for pruning to maxBeamSize + */ + void beamSizePruning(std::vector& extensions, size_t maxBeamSize) const; + + /* + * Helper function for pruning to scoreThreshold + */ + void scorePruning(std::vector& extensions, Score scoreThreshold) const; + + /* + * Helper function for recombination of hypotheses at the same point in the tree with the same scoring context and LM history + */ + void recombination(std::vector& hypotheses); + + /* + * Precompute information about the successor structure of each state in the search tree + * to avoid repeated computation during the decode steps + * stateSuccessorLookup_: contains a list of all state successors for the state at the corresponding index + * exitLookup_: contains a list of all exits for the state at the corresponding index + */ + // TODO make this more efficient, especially for states with only one exit (cf. AdvancedTreeSearch) + void createSuccessorLookups(); + + /* + * After reaching the segment end, go through the active hypotheses, only keep those + * which are at a word end (in the root state) and add the sentence end LM score. 
+ * If no word-end hypotheses exist, use sentence-end fallback or construct an empty hypothesis + */ + void finalizeLmScoring(); +}; + +} // namespace Search + +#endif // TREE_TIMESYNC_BEAM_SEARCH_HH From cdd2b8f9ab8227fc3a8e256bafc16eb106ee089d Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 7 Apr 2025 14:43:44 +0200 Subject: [PATCH 02/11] Format --- src/Search/Module.cc | 4 +- .../TreeTimesyncBeamSearch.cc | 76 +++++++++---------- .../TreeTimesyncBeamSearch.hh | 4 +- 3 files changed, 41 insertions(+), 43 deletions(-) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index ad2553525..9878c7bec 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -17,8 +17,8 @@ #include #include #include "LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh" -#include "TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh" #include "TreeBuilder.hh" +#include "TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh" #ifdef MODULE_SEARCH_WFST #include #include @@ -119,7 +119,7 @@ SearchAlgorithmV2* Module_::createSearchAlgorithmV2(const Core::Configuration& c break; case TreeTimesyncBeamSearchType: searchAlgorithm = new Search::TreeTimesyncBeamSearch(config); - break; + break; default: Core::Application::us()->criticalError("Unknown search algorithm type: %d", searchTypeV2Param(config)); break; diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 10cdd589e..60a911938 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -22,8 +22,8 @@ #include #include #include -#include "Search/Traceback.hh" #include "Search/Module.hh" +#include "Search/Traceback.hh" namespace Search { @@ -51,7 +51,6 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( lmHistory(extension.lmHistory), score(extension.score), trace() { - switch (extension.transitionType) { case Nn::LabelScorer::INITIAL_BLANK: case Nn::LabelScorer::INITIAL_LABEL: @@ -68,11 +67,11 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( case Nn::LabelScorer::LABEL_TO_BLANK: if (base.trace->pronunciation != nullptr) { // A word has ended before and the first token of a new word was predicted -> start a new trace trace = Core::ref(new LatticeTrace( - base.trace, - extension.pron, - extension.timeframe + 1, - {base.trace->score.acoustic + (extension.score - base.score - extension.lmScore), base.trace->score.lm + extension.lmScore}, - {})); + base.trace, + extension.pron, + extension.timeframe + 1, + {base.trace->score.acoustic + (extension.score - base.score - extension.lmScore), base.trace->score.lm + extension.lmScore}, + {})); } else { // Word-end or within-word hypothesis and no word has ended before -> update the old trace trace = Core::ref(new LatticeTrace(*base.trace)); @@ -212,7 +211,7 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Build the search tree log() << "Start building search tree"; - network_ = Core::ref(new PersistentStateTree(config, acousticModel_, lexicon_, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5))); + network_ = Core::ref(new PersistentStateTree(config, acousticModel_, lexicon_, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5))); 
std::unique_ptr builder = Search::Module::instance().createTreeBuilder(config, *lexicon_, *acousticModel_, *network_); builder->build(); log() << "Building finished"; @@ -251,8 +250,8 @@ void TreeTimesyncBeamSearch::reset() { beam_.clear(); beam_.push_back(LabelHypothesis()); beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); - beam_.front().currentState = network_->rootState; - beam_.front().lmHistory = languageModel_->startHistory(); + beam_.front().currentState = network_->rootState; + beam_.front().lmHistory = languageModel_->startHistory(); finishedSegment_ = false; @@ -329,22 +328,22 @@ bool TreeTimesyncBeamSearch::decodeStep() { auto& hyp = beam_[hypIndex]; // Iterate over the successors of this hypothesis' current state in the tree - for (const auto& successorState: stateSuccessorLookup_[hyp.currentState]) { + for (const auto& successorState : stateSuccessorLookup_[hyp.currentState]) { Nn::LabelIndex tokenIdx = network_->structure.state(successorState).stateDesc.acousticModel; // If we want to force blank between repeated labels across words, a new word should not start with the same token as the previous word ended (except for blank itself) // If we don't force blank and we have a repeated label across words, we need to make sure to have label-to-Label as transition type - if (not (forceBlankAcrossWords_ and (hyp.currentState == network_->rootState) and (tokenIdx == hyp.currentToken) and (tokenIdx != blankLabelIndex_))) { + if (not(forceBlankAcrossWords_ and (hyp.currentState == network_->rootState) and (tokenIdx == hyp.currentToken) and (tokenIdx != blankLabelIndex_))) { auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx, hyp.currentState == network_->rootState); extensions_.push_back( - {tokenIdx, - nullptr, - successorState, - hyp.lmHistory, - hyp.score, - 0.0, - 0, - transitionType, - hypIndex}); + {tokenIdx, + nullptr, + successorState, + hyp.lmHistory, + hyp.score, + 0.0, + 0, + transitionType, + hypIndex}); requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); } } @@ -371,7 +370,6 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlOpen("search-step-stats"); } - /* * Prune set of possible within-word extensions by max beam size and possibly also by score. 
*/ @@ -392,7 +390,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { */ withinWordExtensions_.clear(); wordEndExtensions_.clear(); - for (const auto& extension: extensions_) { + for (const auto& extension : extensions_) { // If there is at least one state successor, keep it as within-word hypothesis if (not stateSuccessorLookup_[extension.state].empty()) { withinWordExtensions_.push_back(extension); @@ -400,23 +398,23 @@ bool TreeTimesyncBeamSearch::decodeStep() { std::vector exitList = exitLookup_[extension.state]; if (not exitList.empty()) { // Create one word-end hypothesis for each exit - for (const auto& exit: exitList) { - ExtensionCandidate wordEndExtension(extension); + for (const auto& exit : exitList) { + ExtensionCandidate wordEndExtension(extension); const Bliss::LemmaPronunciation* lemmaPron = lexicon_->lemmaPronunciation(exit.pronunciation); - const Bliss::Lemma* lemma = lemmaPron->lemma(); + const Bliss::Lemma* lemma = lemmaPron->lemma(); // Start from the root node (the exit's transit state) in the next step wordEndExtension.state = exit.transitState; - wordEndExtension.pron = lemmaPron; + wordEndExtension.pron = lemmaPron; if (lemma != lexicon_->specialLemma("blank")) { const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence(); - const Bliss::SyntacticToken* st = sts.front(); + const Bliss::SyntacticToken* st = sts.front(); // Add the LM score and update the LM history Lm::Score lmScore = languageModel_->score(wordEndExtension.lmHistory, st); wordEndExtension.score += lmScore; - wordEndExtension.lmScore = lmScore; + wordEndExtension.lmScore = lmScore; wordEndExtension.lmHistory = languageModel_->extendedHistory(wordEndExtension.lmHistory, st); } wordEndExtensions_.push_back(wordEndExtension); @@ -446,7 +444,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { extensions_.swap(withinWordExtensions_); extensions_.insert(extensions_.end(), wordEndExtensions_.begin(), wordEndExtensions_.end()); - for (auto const& extension: extensions_) { + for (auto const& extension : extensions_) { auto const& baseHyp = beam_[extension.baseHypIndex]; auto newScoringContext = labelScorer_->extendedScoringContext( @@ -479,7 +477,6 @@ bool TreeTimesyncBeamSearch::decodeStep() { beam_.swap(newBeam_); - if (logStepwiseStatistics_) { clog() << Core::XmlFull("best-hyp-score", getBestHypothesis().score); clog() << Core::XmlFull("worst-hyp-score", getWorstHypothesis().score); @@ -597,7 +594,7 @@ void TreeTimesyncBeamSearch::recombination(std::vectorstructure.stateCount()); for (u32 state = 1; state < network_->structure.stateCount(); ++state) { - std::vector stateList; // Collect the state successors of all nodes - std::vector exitList; // Collect the exits of all nodes + std::vector stateList; // Collect the state successors of all nodes + std::vector exitList; // Collect the exits of all nodes for (HMMStateNetwork::SuccessorIterator it = network_->structure.successors(state); it; ++it) { if (not it.isLabel()) { stateList.push_back(*it); - } else { + } + else { exitList.push_back(network_->exits[it.label()]); } } stateSuccessorLookup_[state] = stateList; - exitLookup_[state] = exitList; + exitLookup_[state] = exitList; // Retrieve the maximal number of exits a node in the tree can have to estimate the size of pre-allocated vectors if (exitList.size() > maxNumberOfExits_) { @@ -687,7 +685,7 @@ void TreeTimesyncBeamSearch::finalizeLmScoring() { log() << "Use sentence-end fallback"; // The trace of the unfinished word keeps an empty pronunciation, only the LM score is added for (size_t 
hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { - auto& hyp = beam_[hypIndex]; + auto& hyp = beam_[hypIndex]; Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); hyp.score += sentenceEndScore; hyp.trace->score.lm += sentenceEndScore; @@ -697,9 +695,9 @@ void TreeTimesyncBeamSearch::finalizeLmScoring() { else { // Construct an empty hypothesis with a lattice containing only one empty pronunciation from start to end newBeam_.push_back(LabelHypothesis()); - newBeam_.front().trace->time = beam_.front().trace->time; // Retrieve the timeframe from any hyp in the old beam + newBeam_.front().trace->time = beam_.front().trace->time; // Retrieve the timeframe from any hyp in the old beam newBeam_.front().trace->pronunciation = nullptr; - newBeam_.front().trace->predecessor = Core::ref(new LatticeTrace(0, {0, 0}, {})); + newBeam_.front().trace->predecessor = Core::ref(new LatticeTrace(0, {0, 0}, {})); } } beam_.swap(newBeam_); diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index c2be05881..9f6ae73ca 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -22,9 +22,9 @@ #include #include #include +#include #include #include -#include namespace Search { @@ -176,7 +176,7 @@ private: * Infer type of transition between two tokens based on whether each of them is blank * and/or whether they are the same */ - Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel, bool inRoot=false) const; + Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel, bool inRoot = false) const; /* * Helper function for pruning to maxBeamSize From e5953f2ed6f97bfec769df0b50c8a21b160b33b7 Mon Sep 17 00:00:00 2001 From: Larissa Date: Thu, 24 Apr 2025 18:34:34 +0200 Subject: [PATCH 03/11] Some fixes --- .../TreeTimesyncBeamSearch.cc | 19 +++++++------------ .../TreeTimesyncBeamSearch.hh | 6 +++--- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 60a911938..84688ab81 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -175,6 +175,7 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config newBeam_(), requests_(), recombinedHypotheses_(), + maxNumberOfExits_(0), initializationTime_(), featureProcessingTime_(), scoringTime_(), @@ -276,21 +277,15 @@ void TreeTimesyncBeamSearch::finishSegment() { finalizeLmScoring(); } -void TreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { +void TreeTimesyncBeamSearch::putFeature(Nn::DataView const& feature) { featureProcessingTime_.start(); - labelScorer_->addInput(data, featureSize); + labelScorer_->addInput(feature); featureProcessingTime_.stop(); } -void TreeTimesyncBeamSearch::putFeature(std::vector const& data) { +void TreeTimesyncBeamSearch::putFeatures(Nn::DataView const& features, size_t nTimesteps) { featureProcessingTime_.start(); - labelScorer_->addInput(data); - featureProcessingTime_.stop(); -} - -void TreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { - featureProcessingTime_.start(); - labelScorer_->addInputs(data, timeSize, featureSize); + 
labelScorer_->addInputs(features, nTimesteps); featureProcessingTime_.stop(); } @@ -642,8 +637,8 @@ void TreeTimesyncBeamSearch::recombination(std::vectorstructure.stateCount()); - exitLookup_.reserve(network_->structure.stateCount()); + stateSuccessorLookup_.resize(network_->structure.stateCount()); + exitLookup_.resize(network_->structure.stateCount()); for (u32 state = 1; state < network_->structure.stateCount(); ++state) { std::vector stateList; // Collect the state successors of all nodes diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 9f6ae73ca..0f44eb8d7 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -108,9 +109,8 @@ public: void reset() override; void enterSegment(Bliss::SpeechSegment const* = nullptr) override; void finishSegment() override; - void putFeature(std::shared_ptr const& data, size_t featureSize) override; - void putFeature(std::vector const& data) override; - void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) override; + void putFeature(Nn::DataView const& feature) override; + void putFeatures(Nn::DataView const& features, size_t nTimesteps) override; Core::Ref getCurrentBestTraceback() const override; Core::Ref getCurrentBestWordLattice() const override; bool decodeStep() override; From fb22b38d1fdb7d602e0a10c72053bfa12ccab778 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 20 May 2025 10:03:29 +0200 Subject: [PATCH 04/11] Correct vector allocation --- src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 84688ab81..9f6131542 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -223,7 +223,7 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Pre-allocate vectors // If maxWordEndBeamSize_ is not set, we need the maximum number of exits a node can have for estimating the max. size of the vectors - int maxWordEnds = maxWordEndBeamSize_ == std::numeric_limits::max() ? maxNumberOfExits_ : maxWordEndBeamSize_; + int maxWordEnds = maxWordEndBeamSize_ == std::numeric_limits::max() ? 
(maxNumberOfExits_ * maxBeamSize_) : maxWordEndBeamSize_; // The beam contains all within-word and word-end hypotheses which survived pruning beam_.reserve(maxBeamSize_ + maxWordEnds); From ba116c8ebdf8bc2fd9a701881571b9f5694bca14 Mon Sep 17 00:00:00 2001 From: Larissa Date: Fri, 6 Jun 2025 16:15:37 +0200 Subject: [PATCH 05/11] Introduce Label Scorer cache cleanup --- .../TreeTimesyncBeamSearch.cc | 36 +++++++++++++++---- .../TreeTimesyncBeamSearch.hh | 6 +++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 9f6131542..e21a3667b 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -93,6 +94,8 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( trace->score.acoustic = base.trace->score.acoustic + (extension.score - base.score - extension.lmScore); trace->score.lm = base.trace->score.lm + extension.lmScore; break; + default: + defect(); // Unexpected transition type which can not be produced by `inferTransitionType` } } @@ -157,6 +160,11 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramLogStepwiseStatistics( "Log statistics about the beam at every search step.", false); +const Core::ParameterBool TreeTimesyncBeamSearch::paramCacheCleanupInterval( + "cache-cleanup-interval", + "Interval of search steps after which buffered inputs that are not needed anymore get cleaned up.", + 10); + TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config) : Core::Component(config), SearchAlgorithmV2(config), @@ -168,6 +176,7 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config forceBlankAcrossWords_(paramForceBlankAcrossWords(config)), sentenceEndFallback_(paramSentenceEndFallBack(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + cacheCleanupInterval_(paramCacheCleanupInterval(config)), debugChannel_(config, "debug"), labelScorer_(), beam_(), @@ -185,6 +194,7 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config numWordEndHypsAfterScorePruning_("num-word-end-hyps-after-score-pruning"), numWordEndHypsAfterBeamPruning_("num-word-end-hyps-after-beam-pruning"), numActiveHyps_("num-active-hyps"), + currentSearchStep_(0ul), finishedSegment_(false) { if (wordEndScoreThreshold_ <= 1.0) { if (scoreThreshold_ == Core::Type::max) { @@ -254,7 +264,8 @@ void TreeTimesyncBeamSearch::reset() { beam_.front().currentState = network_->rootState; beam_.front().lmHistory = languageModel_->startHistory(); - finishedSegment_ = false; + currentSearchStep_ = 0ul; + finishedSegment_ = false; initializationTime_.stop(); } @@ -457,22 +468,33 @@ bool TreeTimesyncBeamSearch::decodeStep() { recombination(newBeam_); numActiveHyps_ += newBeam_.size(); - if (logStepwiseStatistics_) { - clog() << Core::XmlFull("active-hyps", newBeam_.size()); + /* + * Clean up label scorer caches. + */ + if (++currentSearchStep_ % cacheCleanupInterval_ == 0) { + Core::CollapsedVector activeContexts; + for (auto const& hyp : newBeam_) { + activeContexts.push_back(hyp.scoringContext); + } + labelScorer_->cleanupCaches(activeContexts); } + /* + * Log statistics about the new beam after this step. 
+ */ + beam_.swap(newBeam_); + if (debugChannel_.isOpen()) { std::stringstream ss; - for (size_t hypIdx = 0ul; hypIdx < newBeam_.size(); ++hypIdx) { - ss << "Hypothesis " << hypIdx + 1ul << ": " << newBeam_[hypIdx].toString() << "\n"; + for (size_t hypIdx = 0ul; hypIdx < beam_.size(); ++hypIdx) { + ss << "Hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; } ss << "\n"; debugChannel_ << ss.str(); } - beam_.swap(newBeam_); - if (logStepwiseStatistics_) { + clog() << Core::XmlFull("active-hyps", beam_.size()); clog() << Core::XmlFull("best-hyp-score", getBestHypothesis().score); clog() << Core::XmlFull("worst-hyp-score", getWorstHypothesis().score); clog() << Core::XmlClose("search-step-stats"); diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 0f44eb8d7..1218b9cf0 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -98,6 +98,7 @@ public: static const Core::ParameterBool paramForceBlankAcrossWords; static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramLogStepwiseStatistics; + static const Core::ParameterBool paramCacheCleanupInterval; TreeTimesyncBeamSearch(Core::Configuration const&); @@ -131,6 +132,8 @@ private: bool logStepwiseStatistics_; + size_t cacheCleanupInterval_; + Core::Channel debugChannel_; Core::Ref labelScorer_; @@ -164,7 +167,8 @@ private: Core::Statistics numWordEndHypsAfterBeamPruning_; Core::Statistics numActiveHyps_; - bool finishedSegment_; + size_t currentSearchStep_; + bool finishedSegment_; LabelHypothesis const& getBestHypothesis() const; LabelHypothesis const& getWorstHypothesis() const; From 099588cea1fa67793d1424ab0f8a84d1b1c84a0a Mon Sep 17 00:00:00 2001 From: Larissa Date: Sun, 8 Jun 2025 18:03:57 +0200 Subject: [PATCH 06/11] Integrate LM lookahead into tree search --- src/Search/LanguageModelLookahead.cc | 4 +- .../TreeTimesyncBeamSearch.cc | 135 +++++++++++++++++- .../TreeTimesyncBeamSearch.hh | 72 +++++++--- 3 files changed, 188 insertions(+), 23 deletions(-) diff --git a/src/Search/LanguageModelLookahead.cc b/src/Search/LanguageModelLookahead.cc index 5044991d6..23edda2ac 100644 --- a/src/Search/LanguageModelLookahead.cc +++ b/src/Search/LanguageModelLookahead.cc @@ -649,6 +649,8 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const& for (HMMStateNetwork::SuccessorIterator target = tree_.successors(node); target; ++target) { if (not target.isLabel()) { + if (*target == node) + continue; build(*target, depth + 1); successors.push_back(*target); } @@ -743,7 +745,7 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const& collected[node] = -2; for (HMMStateNetwork::SuccessorIterator edges = tree_.successors(node); edges; ++edges) { - if (not edges.isLabel()) { + if (not edges.isLabel() and *edges != node) { int depth2 = collectTopologicalStates(*edges, depth + 1, topologicalStates, collected); if (depth2 - 1 < depth) { depth = depth2 - 1; diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index e21a3667b..8e92476df 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include "Search/Module.hh" @@ -38,7 +40,10 @@ 
TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Core::Type::max), currentState(invalidTreeNodeIndex), + lookahead(), lmHistory(), + lookaheadHistory(), + fullLookaheadHistory(), score(0.0), trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} @@ -49,7 +54,10 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), currentState(extension.state), + lookahead(extension.lookahead), lmHistory(extension.lmHistory), + lookaheadHistory(extension.lookaheadHistory), + fullLookaheadHistory(extension.fullLookaheadHistory), score(extension.score), trace() { switch (extension.transitionType) { @@ -150,6 +158,21 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramForceBlankAcrossWords( "Require a blank label between identical labels at word end and word begin.", false); +const Core::ParameterBool TreeTimesyncBeamSearch::paramLmLookahead( + "lm-lookahead", + "Enable language model lookahead.", + true); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSeparateLookaheadLm( + "separate-lookahead-lm", + "Use a separate LM for lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSparseLmLookAhead( + "sparse-lm-lookahead", + "Use sparse n-gram LM lookahead.", + true); + const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack( "sentence-end-fall-back", "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", @@ -174,6 +197,10 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config wordEndScoreThreshold_(paramWordEndScoreThreshold(config)), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), forceBlankAcrossWords_(paramForceBlankAcrossWords(config)), + enableLmLookahead_(paramLmLookahead(config)), + separateLookaheadLm_(paramSeparateLookaheadLm(config)), + sparseLmLookahead_(paramSparseLmLookAhead(config)), + lmLookaheadCache_(1000), sentenceEndFallback_(paramSentenceEndFallBack(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), cacheCleanupInterval_(paramCacheCleanupInterval(config)), @@ -230,6 +257,34 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Create look-ups for state successors and exits of each state createSuccessorLookups(); + // Set lookahead LM + if (enableLmLookahead_) { + if (separateLookaheadLm_) { + log() << "Use separate lookahead LM"; + lookaheadLm_ = Lm::Module::instance().createScaledLanguageModel(select("lm-lookahead"), lexicon_); + } + else if (languageModel_->lookaheadLanguageModel().get() != nullptr) { + lookaheadLm_ = Core::Ref(new Lm::LanguageModelScaling(select("lookahead-lm"), + Core::Ref(const_cast(languageModel_->lookaheadLanguageModel().get())))); + } + else { + lookaheadLm_ = languageModel_; + } + + if (sparseLmLookahead_ && !dynamic_cast(lookaheadLm_->unscaled().get())) { + warning() << "Not using sparse LM lookahead, because the LM is not a backing-off LM."; + sparseLmLookahead_ = false; + } + + lmLookahead_ = new LanguageModelLookahead(Core::Configuration(config, "lm-lookahead"), + modelCombination.pronunciationScale(), + lookaheadLm_, + network_->structure, + network_->rootState, + network_->exits, + acousticModel_); + } + // Pre-allocate vectors // If maxWordEndBeamSize_ is not set, we need the maximum number of exits a node can have for estimating the max. 
size of the vectors @@ -256,6 +311,7 @@ void TreeTimesyncBeamSearch::reset() { initializationTime_.start(); labelScorer_->reset(); + lmLookaheadCache_.clear(); // Reset beam to a single empty hypothesis beam_.clear(); @@ -264,6 +320,11 @@ void TreeTimesyncBeamSearch::reset() { beam_.front().currentState = network_->rootState; beam_.front().lmHistory = languageModel_->startHistory(); + if (enableLmLookahead_) { + beam_.front().lookaheadHistory = lookaheadLm_->startHistory(); + beam_.front().fullLookaheadHistory = lookaheadLm_->startHistory(); + } + currentSearchStep_ = 0ul; finishedSegment_ = false; @@ -344,7 +405,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { {tokenIdx, nullptr, successorState, + hyp.lookahead, hyp.lmHistory, + hyp.lookaheadHistory, + hyp.fullLookaheadHistory, hyp.score, 0.0, 0, @@ -370,6 +434,14 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t requestIdx = 0ul; requestIdx < extensions_.size(); ++requestIdx) { extensions_[requestIdx].score += result->scores[requestIdx]; extensions_[requestIdx].timeframe = result->timeframes[requestIdx]; + + // Add the LM lookahead score to the extensions' scores for pruning + // Make sure not to calculate the lookahead score for the blank lemma which is reachable from the root + if (enableLmLookahead_ and not(beam_[extensions_[requestIdx].baseHypIndex].currentState == network_->rootState and extensions_[requestIdx].nextToken == blankLabelIndex_)) { + Score lookaheadScore = getLmLookaheadScore(extensions_[requestIdx]); + extensions_[requestIdx].lmScore = lookaheadScore; + extensions_[requestIdx].score += lookaheadScore; + } } if (logStepwiseStatistics_) { @@ -396,7 +468,13 @@ bool TreeTimesyncBeamSearch::decodeStep() { */ withinWordExtensions_.clear(); wordEndExtensions_.clear(); - for (const auto& extension : extensions_) { + for (auto& extension : extensions_) { + if (enableLmLookahead_) { + // Subtract the LM lookahead score again + extension.score -= extension.lmScore; + extension.lmScore = 0; + } + // If there is at least one state successor, keep it as within-word hypothesis if (not stateSuccessorLookup_[extension.state].empty()) { withinWordExtensions_.push_back(extension); @@ -443,6 +521,20 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndExtensions_.size()); } + // If the lookahead history has changed, prepare new lookahead for the next timeframe + if (enableLmLookahead_) { + for (auto& wordEndExtension : wordEndExtensions_) { + const Bliss::SyntacticToken* st = wordEndExtension.pron->lemma()->syntacticTokenSequence().front(); + Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(wordEndExtension.fullLookaheadHistory, st); + + if (!(newLookaheadHistory == wordEndExtension.lookaheadHistory)) { + getLmLookahead(wordEndExtension.lookahead, newLookaheadHistory); + wordEndExtension.lookaheadHistory = newLookaheadHistory; + wordEndExtension.fullLookaheadHistory = newLookaheadHistory; + } + } + } + /* * Create new beam from surviving extensions. 
*/ @@ -575,6 +667,47 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn:: } } +void TreeTimesyncBeamSearch::getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history) { + if (lmLookaheadCache_.contains(history)) { + lookahead = lmLookaheadCache_[history]; + } + else { + lookahead = lmLookahead_->getLookahead(history); + lmLookahead_->fill(lookahead, sparseLmLookahead_); + lmLookaheadCache_.put(history, lookahead); + } +} + +Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension) { + if (!extension.lookahead) { + getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + + Score lookaheadScore = 0; + bool scoreFound = false; + do { + if (extension.lookahead->isSparse()) { // Non-sparse lookahead + auto lookaheadHash = lmLookahead_->lookaheadHash(extension.state); + scoreFound = extension.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore); + } + else { // Sparse lookahead + auto lookaheadId = lmLookahead_->lookaheadId(extension.state); + lookaheadScore = extension.lookahead->scoreForLookAheadIdNormal(lookaheadId); + scoreFound = true; + } + + if (!scoreFound) { // No lookahead table entry, use back-off + const Lm::BackingOffLm* lm = dynamic_cast(lookaheadLm_->unscaled().get()); + lookaheadScore += lm->getBackOffScore(extension.lookaheadHistory); + // Reduce the history and retrieve the corresponding lookahead table + extension.lookaheadHistory = lm->reducedHistory(extension.lookaheadHistory, lm->historyLength(extension.lookaheadHistory) - 1); + getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + } while (!scoreFound); + + return lookaheadScore; +} + void TreeTimesyncBeamSearch::beamSizePruning(std::vector& extensions, size_t maxBeamSize) const { if (extensions.size() <= maxBeamSize) { return; diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 1218b9cf0..c4f35673f 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include "Search/LanguageModelLookahead.hh" namespace Search { @@ -33,6 +35,7 @@ namespace Search { * Simple time synchronous beam search algorithm on a search tree built by the CtcTreeBuilder oder RnaTreeBuilder. * At a word end, a language model score is added to the hypothesis score, * if no language model should be used, the LM-scale has to be set to 0.0. + * Full or sparse language model lookahead can optionally be used with the same or with a separate LM. * Supports global or separate pruning of within-word and word-end hypotheses * by max beam-size and by score difference to the best hypothesis. * Uses a LabelScorer to context initialization/extension and scoring. @@ -50,15 +53,18 @@ protected: * Possible extension for some label hypothesis in the beam */ struct ExtensionCandidate { - Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with - const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end - StateId state; // State in the search tree of this extension - Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end - Score score; // Would-be total score of the full hypothesis after extension (incl. 
LM score) - Score lmScore; // Would-be LM score of a word-end hypothesis after extension - Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback - Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` - size_t baseHypIndex; // Index of base hypothesis in global beam + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end + StateId state; // State in the search tree of this extension + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table, possibly updated at a word end + Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead which will be expanded at a word end + Score score; // Would-be total score of the full hypothesis after extension (incl. LM score) + Score lmScore; // Would-be LM score of a word-end hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam bool operator<(ExtensionCandidate const& other) { return score < other.score; @@ -69,12 +75,16 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - StateId currentState; // Current state in the search tree - Lm::History lmHistory; // Language model history - Score score; // Full score of the hypothesis - Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + StateId currentState; // Current state in the search tree + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table + Lm::History lmHistory; // Language model history + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead + Score score; // Full score of the hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + size_t baseHypIndex; LabelHypothesis(); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); @@ -96,6 +106,9 @@ public: static const Core::ParameterFloat paramWordEndScoreThreshold; static const Core::ParameterBool paramCollapseRepeatedLabels; static const Core::ParameterBool paramForceBlankAcrossWords; + static const Core::ParameterBool paramLmLookahead; + static const Core::ParameterBool paramSeparateLookaheadLm; + static const Core::ParameterBool paramSparseLmLookAhead; static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramLogStepwiseStatistics; static const Core::ParameterBool paramCacheCleanupInterval; @@ -128,6 +141,12 @@ private: bool 
collapseRepeatedLabels_; bool forceBlankAcrossWords_; + bool enableLmLookahead_; + bool separateLookaheadLm_; + bool sparseLmLookahead_; + LanguageModelLookahead* lmLookahead_; + Core::FIFOCache lmLookaheadCache_; + bool sentenceEndFallback_; bool logStepwiseStatistics_; @@ -136,12 +155,13 @@ private: Core::Channel debugChannel_; - Core::Ref labelScorer_; - Bliss::LexiconRef lexicon_; - Core::Ref network_; - Core::Ref acousticModel_; - Core::Ref languageModel_; - std::vector beam_; + Core::Ref labelScorer_; + Bliss::LexiconRef lexicon_; + Core::Ref network_; + Core::Ref acousticModel_; + Core::Ref languageModel_; + Core::Ref lookaheadLm_; + std::vector beam_; // Pre-allocated intermediate vectors std::vector extensions_; @@ -182,6 +202,16 @@ private: */ Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel, bool inRoot = false) const; + /* + * Retrieve the LM lookahead for the given history from cache or compute and cache it if missing + */ + void getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history); + + /* + * Compute the sparse or non-sparse LM lookahead score for an extension's state and history, with back-off if needed + */ + Score getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension); + /* * Helper function for pruning to maxBeamSize */ From 6f0fd542586f1a263742d24874decf34e32a18a4 Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 15 Sep 2025 17:25:45 +0200 Subject: [PATCH 07/11] Sync with changes in TreeTimesyncBeamSearch --- .../TreeTimesyncBeamSearch.cc | 127 +++++++++++++++++- .../TreeTimesyncBeamSearch.hh | 60 ++++++--- 2 files changed, 169 insertions(+), 18 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 240d63b51..4de389c31 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include "Search/Module.hh" @@ -38,8 +40,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Nn::invalidLabelIndex), currentState(invalidTreeNodeIndex), + lookahead(), lmHistory(), + lookaheadHistory(), + fullLookaheadHistory(), score(0.0), + lookaheadScore(0.0), trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( @@ -49,8 +55,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), currentState(extension.state), + lookahead(extension.lookahead), lmHistory(extension.lmHistory), + lookaheadHistory(extension.lookaheadHistory), + fullLookaheadHistory(extension.fullLookaheadHistory), score(extension.score), + lookaheadScore(extension.lmScore), trace(base.trace) { if (extension.pron != nullptr) { // Word-end hypothesis -> update base trace and start a new trace for the next word auto completedTrace = Core::ref(new LatticeTrace(*base.trace)); @@ -115,6 +125,21 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramCollapseRepeatedLabels( "Collapse repeated emission of the same label into one output. 
If false, every emission is treated like a new output.", false); +const Core::ParameterBool TreeTimesyncBeamSearch::paramLmLookahead( + "lm-lookahead", + "Enable language model lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSeparateLookaheadLm( + "separate-lookahead-lm", + "Use a separate LM for lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSparseLmLookAhead( + "sparse-lm-lookahead", + "Use sparse n-gram LM lookahead.", + true); + const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack( "sentence-end-fall-back", "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", @@ -140,6 +165,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config cacheCleanupInterval_(paramCacheCleanupInterval(config)), useBlank_(), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), + enableLmLookahead_(paramLmLookahead(config)), + separateLookaheadLm_(paramSeparateLookaheadLm(config)), + sparseLmLookahead_(paramSparseLmLookAhead(config)), sentenceEndFallback_(paramSentenceEndFallBack(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), labelScorer_(), @@ -215,6 +243,34 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Create look-ups for state successors and exits of each state createSuccessorLookups(); + // Set lookahead LM + if (enableLmLookahead_) { + if (separateLookaheadLm_) { + log() << "Use separate lookahead LM"; + lookaheadLm_ = Lm::Module::instance().createScaledLanguageModel(select("lm-lookahead"), lexicon_); + } + else if (languageModel_->lookaheadLanguageModel().get() != nullptr) { + lookaheadLm_ = Core::Ref(new Lm::LanguageModelScaling(select("lookahead-lm"), + Core::Ref(const_cast(languageModel_->lookaheadLanguageModel().get())))); + } + else { + lookaheadLm_ = languageModel_; + } + + if (sparseLmLookahead_ && !dynamic_cast(lookaheadLm_->unscaled().get())) { + warning() << "Not using sparse LM lookahead, because the LM is not a backing-off LM."; + sparseLmLookahead_ = false; + } + + lmLookahead_ = new LanguageModelLookahead(Core::Configuration(config, "lm-lookahead"), + modelCombination.pronunciationScale(), + lookaheadLm_, + network_->structure, + network_->rootState, + network_->exits, + acousticModel_); + } + reset(); // Create global cache @@ -240,6 +296,11 @@ void TreeTimesyncBeamSearch::reset() { beam_.front().currentState = network_->rootState; beam_.front().lmHistory = languageModel_->startHistory(); + if (enableLmLookahead_) { + beam_.front().lookaheadHistory = lookaheadLm_->startHistory(); + beam_.front().fullLookaheadHistory = lookaheadLm_->startHistory(); + } + currentSearchStep_ = 0ul; finishedSegment_ = false; @@ -330,7 +391,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { {tokenIdx, nullptr, successorState, + hyp.lookahead, hyp.lmHistory, + hyp.lookaheadHistory, + hyp.fullLookaheadHistory, hyp.score, 0.0, 0, @@ -355,6 +419,14 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t requestIdx = 0ul; requestIdx < extensions_.size(); ++requestIdx) { extensions_[requestIdx].score += result->scores[requestIdx]; extensions_[requestIdx].timeframe = result->timeframes[requestIdx]; + + // Add the LM lookahead score to the extensions' scores for pruning + // Make sure not to calculate the lookahead score for the blank lemma which is reachable from the root + if (enableLmLookahead_ and not(beam_[extensions_[requestIdx].baseHypIndex].currentState == network_->rootState and 
extensions_[requestIdx].nextToken == blankLabelIndex_)) { + Score lookaheadScore = getLmLookaheadScore(extensions_[requestIdx]); + extensions_[requestIdx].lmScore = lookaheadScore; + extensions_[requestIdx].score += lookaheadScore; + } } if (logStepwiseStatistics_) { @@ -404,6 +476,12 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < newBeam_.size(); ++hypIndex) { auto& hyp = newBeam_[hypIndex]; + if (enableLmLookahead_) { + // Subtract the LM lookahead score again + hyp.score -= hyp.lookaheadScore; + hyp.lookaheadScore = 0; + } + std::vector exitList = exitLookup_[hyp.currentState]; if (not exitList.empty()) { // Create one word-end hypothesis for each exit @@ -414,7 +492,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { ExtensionCandidate wordEndExtension{hyp.currentToken, lemmaPron, exit.transitState, // Start from the root node (the exit's transit state) in the next step + hyp.lookahead, hyp.lmHistory, + hyp.lookaheadHistory, + hyp.fullLookaheadHistory, hyp.score, 0.0, static_cast(currentSearchStep_), @@ -444,7 +525,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", extensions_.size()); } - // Create new word-end label hypotheses from word-end extension candidates and update the LM history + // Create new word-end label hypotheses from word-end extension candidates, update the LM history and prepare the new lookahead if its history has changed wordEndHypotheses_.clear(); for (auto& extension : extensions_) { const Bliss::Lemma* lemma = extension.pron->lemma(); @@ -452,6 +533,13 @@ bool TreeTimesyncBeamSearch::decodeStep() { const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence(); const Bliss::SyntacticToken* st = sts.front(); extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st); + Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st); + + if (!(newLookaheadHistory == extension.lookaheadHistory)) { + getLmLookahead(extension.lookahead, newLookaheadHistory); + extension.lookaheadHistory = newLookaheadHistory; + extension.fullLookaheadHistory = newLookaheadHistory; + } } auto const& baseHyp = newBeam_[extension.baseHypIndex]; @@ -689,6 +777,41 @@ void TreeTimesyncBeamSearch::recombination(std::vectorgetLookahead(history); + lmLookahead_->fill(lookahead, sparseLmLookahead_); +} + +Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension) { + if (!extension.lookahead) { + getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + + Score lookaheadScore = 0; + bool scoreFound = false; + do { + if (extension.lookahead->isSparse()) { // Non-sparse lookahead + auto lookaheadHash = lmLookahead_->lookaheadHash(extension.state); + scoreFound = extension.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore); + } + else { // Sparse lookahead + auto lookaheadId = lmLookahead_->lookaheadId(extension.state); + lookaheadScore = extension.lookahead->scoreForLookAheadIdNormal(lookaheadId); + scoreFound = true; + } + + if (!scoreFound) { // No lookahead table entry, use back-off + const Lm::BackingOffLm* lm = dynamic_cast(lookaheadLm_->unscaled().get()); + lookaheadScore += lm->getBackOffScore(extension.lookaheadHistory); + // Reduce the history and retrieve the corresponding lookahead table + extension.lookaheadHistory = lm->reducedHistory(extension.lookaheadHistory, lm->historyLength(extension.lookaheadHistory) - 1); + 
getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + } while (!scoreFound); + + return lookaheadScore; +} + void TreeTimesyncBeamSearch::createSuccessorLookups() { stateSuccessorLookup_.resize(network_->structure.stateCount()); exitLookup_.resize(network_->structure.stateCount()); @@ -746,4 +869,4 @@ void TreeTimesyncBeamSearch::finalizeLmScoring() { beam_.swap(newBeam_); } -} // namespace Search +} // namespace Search \ No newline at end of file diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 09672d81e..c7b14afc2 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -26,6 +26,7 @@ #include #include #include +#include "Search/LanguageModelLookahead.hh" namespace Search { @@ -33,6 +34,7 @@ namespace Search { * Simple time synchronous beam search algorithm on a search tree built by a TreeBuilder. * At a word end, a language model score is added to the hypothesis score, * if no language model should be used, the LM-scale has to be set to 0.0. + * Full or sparse language model lookahead can optionally be used with the same or with a separate LM. * Supports global or separate pruning of within-word and word-end hypotheses * by max beam-size and by score difference to the best hypothesis. * Uses a LabelScorer to context initialization/extension and scoring. @@ -48,6 +50,9 @@ public: static const Core::ParameterFloat paramScoreThreshold; static const Core::ParameterFloat paramWordEndScoreThreshold; static const Core::ParameterBool paramCollapseRepeatedLabels; + static const Core::ParameterBool paramLmLookahead; + static const Core::ParameterBool paramSeparateLookaheadLm; + static const Core::ParameterBool paramSparseLmLookAhead; static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramLogStepwiseStatistics; static const Core::ParameterBool paramCacheCleanupInterval; @@ -73,15 +78,18 @@ protected: * Possible extension for some label hypothesis in the beam */ struct ExtensionCandidate { - Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with - const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end - StateId state; // State in the search tree of this extension - Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end - Score score; // Would-be total score of the full hypothesis after extension (incl. 
LM score) - Score lmScore; // Would-be LM score of a word-end hypothesis after extension - Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback - Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` - size_t baseHypIndex; // Index of base hypothesis in global beam + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end + StateId state; // State in the search tree of this extension + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table, possibly updated at a word end + Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead which will be expanded at a word end + Score score; // Would-be total score of the full hypothesis after extension (incl. LM score) + Score lmScore; // Would-be LM score of a word-end hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam bool operator<(ExtensionCandidate const& other) { return score < other.score; @@ -92,12 +100,16 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - StateId currentState; // Current state in the search tree - Lm::History lmHistory; // Language model history - Score score; // Full score of the hypothesis - Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + StateId currentState; // Current state in the search tree + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table + Lm::History lmHistory; // Language model history + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead + Score score; // Full score of the hypothesis + Score lookaheadScore; // LM-lookahead score + Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis LabelHypothesis(); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); @@ -130,8 +142,14 @@ private: Core::Ref network_; Core::Ref acousticModel_; Core::Ref languageModel_; + Core::Ref lookaheadLm_; Core::Channel debugChannel_; + bool enableLmLookahead_; + bool separateLookaheadLm_; + bool sparseLmLookahead_; + LanguageModelLookahead* lmLookahead_; + // Pre-allocated intermediate vectors std::vector extensions_; std::vector beam_; @@ -187,6 +205,16 @@ private: */ void recombination(std::vector& hypotheses); + /* + * Retrieve the LM lookahead for the given history from cache or compute and cache it if missing + */ + void 
getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history); + + /* + * Compute the sparse or non-sparse LM lookahead score for an extension's state and history, with back-off if needed + */ + Score getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension); + /* * Precompute information about the successor structure of each state in the search tree * to avoid repeated computation during the decode steps @@ -206,4 +234,4 @@ private: } // namespace Search -#endif // TREE_TIMESYNC_BEAM_SEARCH_HH +#endif // TREE_TIMESYNC_BEAM_SEARCH_HH \ No newline at end of file From 5b8f64a94c0154057b6d6370545ca4e07a476e9e Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 15 Sep 2025 17:35:17 +0200 Subject: [PATCH 08/11] Bugfix --- .../TreeTimesyncBeamSearch.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 4de389c31..610c2ccdd 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -533,12 +533,15 @@ bool TreeTimesyncBeamSearch::decodeStep() { const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence(); const Bliss::SyntacticToken* st = sts.front(); extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st); - Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st); - if (!(newLookaheadHistory == extension.lookaheadHistory)) { - getLmLookahead(extension.lookahead, newLookaheadHistory); - extension.lookaheadHistory = newLookaheadHistory; - extension.fullLookaheadHistory = newLookaheadHistory; + if (enableLmLookahead_) { + Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st); + + if (!(newLookaheadHistory == extension.lookaheadHistory)) { + getLmLookahead(extension.lookahead, newLookaheadHistory); + extension.lookaheadHistory = newLookaheadHistory; + extension.fullLookaheadHistory = newLookaheadHistory; + } } } From 8a46c0e4fdbf95e894f997c61ad3e86445a6d811 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 4 Nov 2025 16:27:34 +0100 Subject: [PATCH 09/11] Add logging and fix comments --- .../TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index dcd545b57..2f41a64ba 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -686,6 +686,10 @@ void TreeTimesyncBeamSearch::logStatistics() const { numWordEndHypsAfterBeamPruning_.write(clog()); numActiveHyps_.write(clog()); numActiveTrees_.write(clog()); + + if (enableLmLookahead_) { + lmLookahead_->logStatistics(); + } } Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { @@ -829,11 +833,11 @@ Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::Extens Score lookaheadScore = 0; bool scoreFound = false; do { - if (extension.lookahead->isSparse()) { // Non-sparse lookahead + if (extension.lookahead->isSparse()) { // Sparse lookahead auto lookaheadHash = lmLookahead_->lookaheadHash(extension.state); scoreFound = 
extension.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore); } - else { // Sparse lookahead + else { // Non-sparse lookahead auto lookaheadId = lmLookahead_->lookaheadId(extension.state); lookaheadScore = extension.lookahead->scoreForLookAheadIdNormal(lookaheadId); scoreFound = true; From b89544cef900437bbdafd454ad1a473522eb42f8 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 4 Nov 2025 16:50:22 +0100 Subject: [PATCH 10/11] Formatting --- src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 2f41a64ba..983820bd3 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -569,7 +569,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { const Bliss::SyntacticToken* st = sts.front(); extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st); - if (enableLmLookahead_) { + if (enableLmLookahead_) { Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st); if (!(newLookaheadHistory == extension.lookaheadHistory)) { @@ -578,7 +578,6 @@ bool TreeTimesyncBeamSearch::decodeStep() { extension.fullLookaheadHistory = newLookaheadHistory; } } - } auto const& baseHyp = newBeam_[extension.baseHypIndex]; From 23ffb63a52dfaf28f408b5bdab94e252ecb22de9 Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 17 Nov 2025 09:42:15 +0100 Subject: [PATCH 11/11] Solve merge conflicts --- .../TreeTimesyncBeamSearch.cc | 146 +++++++++++++++++- .../TreeTimesyncBeamSearch.hh | 44 ++++-- 2 files changed, 175 insertions(+), 15 deletions(-) diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 88b9ef02b..9b5abdaea 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -39,8 +41,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Nn::invalidLabelIndex), currentState(invalidTreeNodeIndex), + lookahead(), lmHistory(), + lookaheadHistory(), + fullLookaheadHistory(), score(0.0), + lookaheadScore(0.0), trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( @@ -50,22 +56,32 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), currentState(extension.nextState), + lookahead(base.lookahead), lmHistory(base.lmHistory), + lookaheadHistory(base.lookaheadHistory), + fullLookaheadHistory(base.fullLookaheadHistory), timeframe(extension.timeframe), score(extension.score), + lookaheadScore(extension.lookaheadScore), trace(base.trace) { } TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( LabelHypothesis const& base, TreeTimesyncBeamSearch::WordEndExtensionCandidate const& extension, - Lm::History const& newLmHistory) + Lm::History const& newLmHistory, + LanguageModelLookahead::ContextLookaheadReference const newLookahead, + Lm::History const& newLookaheadHistory) : scoringContext(base.scoringContext), currentToken(base.currentToken), currentState(extension.rootState), + lookahead(newLookahead), lmHistory(newLmHistory), + 
lookaheadHistory(newLookaheadHistory), + fullLookaheadHistory(base.fullLookaheadHistory), timeframe(base.timeframe), - score(extension.score) { + score(extension.score), + lookaheadScore(0.0) { auto newLmScore = score - base.score; auto totalLmScore = base.trace->score.lm + newLmScore; auto totalAmScore = score - totalLmScore; @@ -125,6 +141,21 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramCollapseRepeatedLabels( "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.", false); +const Core::ParameterBool TreeTimesyncBeamSearch::paramLmLookahead( + "lm-lookahead", + "Enable language model lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSeparateLookaheadLm( + "separate-lookahead-lm", + "Use a separate LM for lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSparseLmLookAhead( + "sparse-lm-lookahead", + "Use sparse n-gram LM lookahead.", + true); + const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack( "sentence-end-fall-back", "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", @@ -150,6 +181,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config cacheCleanupInterval_(paramCacheCleanupInterval(config)), useBlank_(), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), + enableLmLookahead_(paramLmLookahead(config)), + separateLookaheadLm_(paramSeparateLookaheadLm(config)), + sparseLmLookahead_(paramSparseLmLookAhead(config)), sentenceEndFallback_(paramSentenceEndFallBack(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), labelScorer_(), @@ -232,6 +266,34 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Create look-ups for state successors and exits of each state createSuccessorLookups(); + // Set lookahead LM + if (enableLmLookahead_) { + if (separateLookaheadLm_) { + log() << "Use separate lookahead LM"; + lookaheadLm_ = Lm::Module::instance().createScaledLanguageModel(select("lm-lookahead"), lexicon_); + } + else if (languageModel_->lookaheadLanguageModel().get() != nullptr) { + lookaheadLm_ = Core::Ref(new Lm::LanguageModelScaling(select("lookahead-lm"), + Core::Ref(const_cast(languageModel_->lookaheadLanguageModel().get())))); + } + else { + lookaheadLm_ = languageModel_; + } + + if (sparseLmLookahead_ && !dynamic_cast(lookaheadLm_->unscaled().get())) { + warning() << "Not using sparse LM lookahead, because the LM is not a backing-off LM."; + sparseLmLookahead_ = false; + } + + lmLookahead_ = new LanguageModelLookahead(Core::Configuration(config, "lm-lookahead"), + modelCombination.pronunciationScale(), + lookaheadLm_, + network_->structure, + network_->rootState, + network_->exits, + acousticModel_); + } + reset(); // Create global cache @@ -257,6 +319,11 @@ void TreeTimesyncBeamSearch::reset() { beam_.front().currentState = network_->rootState; beam_.front().lmHistory = languageModel_->startHistory(); + if (enableLmLookahead_) { + beam_.front().lookaheadHistory = lookaheadLm_->startHistory(); + beam_.front().fullLookaheadHistory = lookaheadLm_->startHistory(); + } + currentSearchStep_ = 0ul; finishedSegment_ = false; @@ -366,6 +433,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { successorState, 0, hyp.score, + 0, transitionType, hypIndex}); requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); @@ -387,6 +455,14 @@ bool TreeTimesyncBeamSearch::decodeStep() { for 
(size_t requestIdx = 0ul; requestIdx < withinWordExtensions_.size(); ++requestIdx) { withinWordExtensions_[requestIdx].score += result->scores[requestIdx]; withinWordExtensions_[requestIdx].timeframe = result->timeframes[requestIdx]; + + // Add the LM lookahead score to the extensions' scores for pruning + // Make sure not to calculate the lookahead score for the blank lemma which is reachable from the root + if (enableLmLookahead_ and not(beam_[withinWordExtensions_[requestIdx].baseHypIndex].currentState == network_->rootState and withinWordExtensions_[requestIdx].nextToken == blankLabelIndex_)) { + Score lookaheadScore = getLmLookaheadScore(withinWordExtensions_[requestIdx]); + withinWordExtensions_[requestIdx].lookaheadScore = lookaheadScore; + withinWordExtensions_[requestIdx].score += lookaheadScore; + } } if (logStepwiseStatistics_) { @@ -404,7 +480,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { // Create new label hypotheses from extension candidates newBeam_.clear(); - for (auto const& extension : withinWordExtensions_) { + for (auto extension : withinWordExtensions_) { auto const& baseHyp = beam_[extension.baseHypIndex]; auto newScoringContext = labelScorer_->extendedScoringContext( @@ -436,6 +512,12 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < newBeam_.size(); ++hypIndex) { auto& hyp = newBeam_[hypIndex]; + if (enableLmLookahead_) { + // Subtract the LM lookahead score again + hyp.score -= hyp.lookaheadScore; + hyp.lookaheadScore = 0.0; + } + auto const& exitList = exitLookup_[hyp.currentState]; if (not exitList.empty()) { // Create one word-end hypothesis for each exit @@ -464,7 +546,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", wordEndExtensions_.size()); } - // Create new word-end label hypotheses from word-end extension candidates and update the LM history + // Create new word-end label hypotheses from word-end extension candidates, update the LM history and prepare the new lookahead if its history has changed wordEndHypotheses_.clear(); for (auto& extension : wordEndExtensions_) { auto const& baseHyp = newBeam_[extension.baseHypIndex]; @@ -472,13 +554,24 @@ bool TreeTimesyncBeamSearch::decodeStep() { auto newLmHistory = baseHyp.lmHistory; auto const& sts = extension.pron->lemma()->syntacticTokenSequence(); + LanguageModelLookahead::ContextLookaheadReference newLookahead = baseHyp.lookahead; + Lm::History newLookaheadHistory = baseHyp.fullLookaheadHistory; + if (sts.size() != 0) { require(sts.size() == 1); const Bliss::SyntacticToken* st = sts.front(); newLmHistory = languageModel_->extendedHistory(newLmHistory, st); + + if (enableLmLookahead_) { + newLookaheadHistory = lookaheadLm_->extendedHistory(baseHyp.fullLookaheadHistory, st); + + if (!(newLookaheadHistory == baseHyp.lookaheadHistory)) { + getLmLookahead(newLookahead, newLookaheadHistory); + } + } } - wordEndHypotheses_.push_back({baseHyp, extension, newLmHistory}); + wordEndHypotheses_.push_back({baseHyp, extension, newLmHistory, newLookahead, newLookaheadHistory}); } recombination(wordEndHypotheses_, true); @@ -582,6 +675,10 @@ void TreeTimesyncBeamSearch::logStatistics() const { numWordEndHypsAfterBeamPruning_.write(clog()); numActiveHyps_.write(clog()); numActiveTrees_.write(clog()); + + if (enableLmLookahead_) { + lmLookahead_->logStatistics(); + } } Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { @@ -732,6 +829,43 @@ 
void TreeTimesyncBeamSearch::createSuccessorLookups() { } } +void TreeTimesyncBeamSearch::getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history) { + lookahead = lmLookahead_->getLookahead(history); + lmLookahead_->fill(lookahead, sparseLmLookahead_); +} + +Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::WithinWordExtensionCandidate& extension) { + auto& baseHyp = beam_[extension.baseHypIndex]; + + if (!baseHyp.lookahead) { + getLmLookahead(baseHyp.lookahead, baseHyp.lookaheadHistory); + } + + Score lookaheadScore = 0; + bool scoreFound = false; + do { + if (baseHyp.lookahead->isSparse()) { // Sparse lookahead + auto lookaheadHash = lmLookahead_->lookaheadHash(extension.nextState); + scoreFound = baseHyp.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore); + } + else { // Non-sparse lookahead + auto lookaheadId = lmLookahead_->lookaheadId(extension.nextState); + lookaheadScore = baseHyp.lookahead->scoreForLookAheadIdNormal(lookaheadId); + scoreFound = true; + } + + if (!scoreFound) { // No lookahead table entry, use back-off + const Lm::BackingOffLm* lm = dynamic_cast(lookaheadLm_->unscaled().get()); + lookaheadScore += lm->getBackOffScore(baseHyp.lookaheadHistory); + // Reduce the history and retrieve the corresponding lookahead table + baseHyp.lookaheadHistory = lm->reducedHistory(baseHyp.lookaheadHistory, lm->historyLength(baseHyp.lookaheadHistory) - 1); + getLmLookahead(baseHyp.lookahead, baseHyp.lookaheadHistory); + } + } while (!scoreFound); + + return lookaheadScore; +} + void TreeTimesyncBeamSearch::finalizeLmScoring() { newBeam_.clear(); for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { @@ -769,4 +903,4 @@ void TreeTimesyncBeamSearch::finalizeLmScoring() { beam_.swap(newBeam_); } -} // namespace Search +} // namespace Search \ No newline at end of file diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 28288958b..8a9523aef 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ namespace Search { * Simple time synchronous beam search algorithm on a search tree built by a TreeBuilder. * At a word end, a language model score is added to the hypothesis score, * if no language model should be used, the LM-scale has to be set to 0.0. + * Full or sparse language model lookahead can optionally be used with the same or with a separate LM. * Supports global or separate pruning of within-word and word-end hypotheses * by max beam-size and by score difference to the best hypothesis. * Uses a LabelScorer to context initialization/extension and scoring. 
@@ -48,6 +50,9 @@ public: static const Core::ParameterFloat paramScoreThreshold; static const Core::ParameterFloat paramWordEndScoreThreshold; static const Core::ParameterBool paramCollapseRepeatedLabels; + static const Core::ParameterBool paramLmLookahead; + static const Core::ParameterBool paramSeparateLookaheadLm; + static const Core::ParameterBool paramSparseLmLookAhead; static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramLogStepwiseStatistics; static const Core::ParameterBool paramCacheCleanupInterval; @@ -81,6 +86,7 @@ protected: StateId nextState; // State in the search tree of this extension Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback Score score; // Would-be total score of the full hypothesis after extension + Score lookaheadScore; // LM-lookahead score Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` size_t baseHypIndex; // Index of base hypothesis in beam @@ -104,13 +110,17 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - StateId currentState; // Current state in the search tree - Lm::History lmHistory; // Language model history - Speech::TimeframeIndex timeframe; // Timeframe of current token - Score score; // Full score of the hypothesis - Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + StateId currentState; // Current state in the search tree + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table + Lm::History lmHistory; // Language model history + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead + Speech::TimeframeIndex timeframe; // Timeframe of current token + Score score; // Full score of the hypothesis + Score lookaheadScore; // LM-lookahead score + Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis LabelHypothesis(); @@ -118,7 +128,7 @@ protected: LabelHypothesis(LabelHypothesis const& base, WithinWordExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); // Word-end constructor from base and word-end extension - LabelHypothesis(LabelHypothesis const& base, WordEndExtensionCandidate const& extension, Lm::History const& newLmHistory); + LabelHypothesis(LabelHypothesis const& base, WordEndExtensionCandidate const& extension, Lm::History const& newLmHistory, LanguageModelLookahead::ContextLookaheadReference const newLookahead, Lm::History const& newLookaheadHistory); bool operator<(LabelHypothesis const& other) const { return score < other.score; @@ -148,8 +158,14 @@ private: Core::Ref network_; Core::Ref acousticModel_; Core::Ref languageModel_; + Core::Ref lookaheadLm_; Core::Channel debugChannel_; + bool enableLmLookahead_; + bool separateLookaheadLm_; + bool sparseLmLookahead_; + LanguageModelLookahead* lmLookahead_; + // Pre-allocated intermediate vectors std::vector withinWordExtensions_; std::vector wordEndExtensions_; @@ 
-208,6 +224,16 @@ private: */ void recombination(std::vector& hypotheses, bool createTraceSiblings); + /* + * Retrieve or compute the LM lookahead for the given history + */ + void getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history); + + /* + * Compute the sparse or non-sparse LM lookahead score for an extension's state and history, with back-off if needed + */ + Score getLmLookaheadScore(TreeTimesyncBeamSearch::WithinWordExtensionCandidate& extension); + /* * Precompute information about the successor structure of each state in the search tree * to avoid repeated computation during the decode steps @@ -227,4 +253,4 @@ private: } // namespace Search -#endif // TREE_TIMESYNC_BEAM_SEARCH_HH +#endif // TREE_TIMESYNC_BEAM_SEARCH_HH \ No newline at end of file
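Note on the LM-lookahead back-off implemented in getLmLookaheadScore (patches 06-11): when the sparse lookahead table of the current history has no entry for the requested tree state, the search adds the backing-off LM's back-off score, reduces the history by one word, fetches (or reuses) the lookahead table of the reduced history, and retries until a score is found. The standalone C++ sketch below illustrates only that loop; LookaheadTable, BackingOffChain and their members are simplified stand-ins for this note, not the actual RASR lookahead classes or API.

// Minimal, self-contained sketch of the back-off loop, assuming the shortest
// (empty) history covers every state so the loop always terminates.
#include <cstdio>
#include <optional>
#include <unordered_map>
#include <vector>

using Score   = float;
using StateId = unsigned;

// Stand-in for the (sparse) lookahead table of one LM history:
// only tree states with an n-gram entry carry a score.
struct LookaheadTable {
    std::unordered_map<StateId, Score> sparseScores;

    std::optional<Score> lookup(StateId state) const {
        auto it = sparseScores.find(state);
        if (it == sparseScores.end())
            return std::nullopt;
        return it->second;
    }
};

// Stand-in for the reduced-history chain of a backing-off LM:
// index 0 is the full history, each following entry is one word shorter.
struct BackingOffChain {
    std::vector<LookaheadTable> tables;         // lookahead table per history length
    std::vector<Score>          backOffScores;  // penalty paid when reducing the history
};

// Try the table of the current history; on a miss, pay the back-off penalty,
// drop to the next shorter history and retry.
Score lookaheadScoreWithBackOff(BackingOffChain const& chain, StateId state) {
    Score score = 0;
    for (size_t level = 0; level < chain.tables.size(); ++level) {
        if (auto s = chain.tables[level].lookup(state)) {
            return score + *s;  // found an entry at this history length
        }
        score += chain.backOffScores[level];  // no entry: back off and reduce the history
    }
    return score;  // assumed unreachable if the last table covers every state
}

int main() {
    BackingOffChain chain;
    chain.tables        = {LookaheadTable{{{7u, 1.2f}}},
                           LookaheadTable{{{3u, 0.4f}, {7u, 0.9f}}}};
    chain.backOffScores = {0.7f, 0.0f};

    std::printf("state 7: %.2f\n", lookaheadScoreWithBackOff(chain, 7u));  // hit at full history: 1.20
    std::printf("state 3: %.2f\n", lookaheadScoreWithBackOff(chain, 3u));  // one back-off: 0.70 + 0.40
    return 0;
}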