From 1e7035ebc7539177718632179a017f2cc6cca872 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 19 Feb 2025 19:10:45 +0100 Subject: [PATCH 01/52] Implement simple lexiconfree time-sync beam search --- Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../LexiconfreeTimesyncBeamSearch.cc | 442 ++++++++++++++++++ .../LexiconfreeTimesyncBeamSearch.hh | 170 +++++++ .../LexiconfreeTimesyncBeamSearch/Makefile | 24 + src/Search/Makefile | 4 + src/Search/Traceback.hh | 3 +- 11 files changed, 648 insertions(+), 1 deletion(-) create mode 100644 src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc create mode 100644 src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh create mode 100644 src/Search/LexiconfreeTimesyncBeamSearch/Makefile diff --git a/Modules.make b/Modules.make index a9ee0ae7..4ae7efce 100644 --- a/Modules.make +++ b/Modules.make @@ -148,6 +148,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index f171381f..f427cace 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += 
src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index f171381f..f427cace 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index 2ea9bf10..34a90293 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index af199b57..0a597c8c 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -147,6 +147,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git 
a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index bc36c260..1daa8d5d 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -148,6 +148,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc new file mode 100644 index 00000000..80148ab7 --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -0,0 +1,442 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "LexiconfreeTimesyncBeamSearch.hh" +#include +#include +#include +#include +#include +#include + +namespace Search { + +/* + * ===================================== + * === LexiconfreeTimesyncBeamSearch === + * ===================================== + */ + +const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramMaxBeamSize( + "max-beam-size", + "Maximum number of elements in the search beam.", + 1, 1); + +const Core::ParameterFloat LexiconfreeTimesyncBeamSearch::paramScoreThreshold( + "score-threshold", + "Prune any hypotheses with a score that is at least this much worse than the best hypothesis. If not set, no score pruning will be done.", + Core::Type::max, 0); + +const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex( + "blank-label-index", + "Index of the blank label in the lexicon. If not set, the search will not use blank.", + Core::Type::max); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramAllowLabelLoop( + "allow-label-loop", + "Collapse repeated emission of the same label into one output. 
If false, every emission is treated like a new output.", + false); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramLogStepwiseStatistics( + "log-stepwise-statistics", + "Log statistics about the beam at every search step.", + false); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramDebugLogging( + "debug-logging", + "Enable detailed logging for debugging purposes.", + false); + +LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration const& config) + : Core::Component(config), + SearchAlgorithmV2(config), + maxBeamSize_(paramMaxBeamSize(config)), + scoreThreshold_(paramScoreThreshold(config)), + blankLabelIndex_(paramBlankLabelIndex(config)), + allowLabelLoop_(paramAllowLabelLoop(config)), + logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + debugLogging_(paramDebugLogging(config)), + labelScorer_(), + beam_(), + initializationTime_(), + featureProcessingTime_(), + scoringTime_(), + contextExtensionTime_() { + beam_.reserve(maxBeamSize_); + useBlank_ = blankLabelIndex_ != Core::Type::max; + useScorePruning_ = scoreThreshold_ != Core::Type::max; +} + +void LexiconfreeTimesyncBeamSearch::reset() { + initializationTime_.tic(); + + labelScorer_->reset(); + + // Reset beam to a single empty hypothesis + beam_.clear(); + beam_.push_back(LabelHypothesis()); + beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); + + initializationTime_.toc(); +} + +Speech::ModelCombination::Mode LexiconfreeTimesyncBeamSearch::requiredModelCombination() const { + return Speech::ModelCombination::useLabelScorer | Speech::ModelCombination::useLexicon; +} + +bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& modelCombination) { + lexicon_ = modelCombination.lexicon(); + labelScorer_ = modelCombination.labelScorer(); + + reset(); + return true; +} + +void LexiconfreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) { + initializationTime_.tic(); + 
labelScorer_->reset(); + resetStatistics(); + initializationTime_.toc(); +} + +void LexiconfreeTimesyncBeamSearch::finishSegment() { + featureProcessingTime_.tic(); + labelScorer_->signalNoMoreFeatures(); + featureProcessingTime_.toc(); + decodeManySteps(); + logStatistics(); +} + +void LexiconfreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { + featureProcessingTime_.tic(); + labelScorer_->addInput(data, featureSize); + featureProcessingTime_.toc(); +} + +void LexiconfreeTimesyncBeamSearch::putFeature(std::vector const& data) { + featureProcessingTime_.tic(); + labelScorer_->addInput(data); + featureProcessingTime_.toc(); +} + +void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { + featureProcessingTime_.tic(); + labelScorer_->addInputs(data, timeSize, featureSize); + featureProcessingTime_.toc(); +} + +Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const { + return Core::ref(new Traceback(beam_.front().traceback)); +} + +Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const { + if (beam_.front().traceback.empty()) { + return Core::ref(new Lattice::WordLatticeAdaptor()); + } + + // use default LemmaAlphabet mode of StandardWordLattice + Core::Ref result(new Lattice::StandardWordLattice(lexicon_)); + Core::Ref wordBoundaries(new Lattice::WordBoundaries); + + // create a linear lattice from the traceback + Fsa::State* currentState = result->initialState(); + for (auto it = beam_.front().traceback.begin(); it != beam_.front().traceback.end(); ++it) { + wordBoundaries->set(currentState->id(), Lattice::WordBoundary(it->time)); + Fsa::State* nextState; + if (std::next(it) == beam_.front().traceback.end()) { + nextState = result->finalState(); + } + else { + nextState = result->newState(); + } + ScoreVector scores = it->score; + if (it != beam_.front().traceback.begin()) { + scores -= std::prev(it)->score; + } + 
result->newArc(currentState, nextState, it->pronunciation->lemma(), scores.acoustic, scores.lm); + currentState = nextState; + } + + result->setWordBoundaries(wordBoundaries); + result->addAcyclicProperty(); + + return Core::ref(new Lattice::WordLatticeAdaptor(result)); +} + +void LexiconfreeTimesyncBeamSearch::resetStatistics() { + initializationTime_.reset(); + featureProcessingTime_.reset(); + scoringTime_.reset(); + contextExtensionTime_.reset(); +} + +void LexiconfreeTimesyncBeamSearch::logStatistics() const { + clog() << Core::XmlOpen("timing-statistics") + Core::XmlAttribute("unit", "milliseconds"); + clog() << Core::XmlOpen("initialization-time") << initializationTime_.getTotalMilliseconds() << Core::XmlClose("initialization-time"); + clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.getTotalMilliseconds() << Core::XmlClose("feature-processing-time"); + clog() << Core::XmlOpen("scoring-time") << scoringTime_.getTotalMilliseconds() << Core::XmlClose("scoring-time"); + clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.getTotalMilliseconds() << Core::XmlClose("context-extension-time"); + clog() << Core::XmlClose("timing-statistics"); +} + +Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { + // These checks will result in false if `blankLabelIndex_` is still `Core::Type::max`, i.e. 
no blank is used + bool prevIsBlank = (prevLabel == blankLabelIndex_); + bool nextIsBlank = (nextLabel == blankLabelIndex_); + + if (prevIsBlank) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::BLANK_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL; + } + } + else { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::LABEL_TO_BLANK; + } + else if (allowLabelLoop_ and prevLabel == nextLabel) { + return Nn::LabelScorer::TransitionType::LABEL_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; + } + } +} + +void LexiconfreeTimesyncBeamSearch::beamPruning(std::vector& extensions) const { + if (extensions.size() <= maxBeamSize_) { + return; + } + + // Sort the hypotheses by associated score value such that the first `beamSize_` elements are the best and sorted + std::nth_element(extensions.begin(), extensions.begin() + maxBeamSize_, extensions.end()); + extensions.resize(maxBeamSize_); // Get rid of excessive elements +} + +void LexiconfreeTimesyncBeamSearch::scorePruning(std::vector& extensions) const { + // Compute the pruning threshold + auto pruningThreshold = extensions.front().score + scoreThreshold_; + size_t numSurvivingHyps = 0ul; + // Use the fact that hypotheses are sorted by corresponding score and prune all indices after the first one that + // violates the score threshold + for (auto const& ext : extensions) { + if (ext.score > pruningThreshold) { + break; + } + ++numSurvivingHyps; + } + extensions.resize(numSurvivingHyps); // Resize the hypotheses to keep only the surviving items +} + +void LexiconfreeTimesyncBeamSearch::recombination(std::vector& hypotheses) { + std::vector recombinedHypotheses; + + std::unordered_set seenScoringContexts; + for (auto const& hyp : hypotheses) { + if (seenScoringContexts.find(hyp.scoringContext) == seenScoringContexts.end()) { + recombinedHypotheses.push_back(hyp); + seenScoringContexts.insert(hyp.scoringContext); + } + } + 
hypotheses.swap(recombinedHypotheses); +} + +bool LexiconfreeTimesyncBeamSearch::decodeStep() { + // Assume the output labels are stored as lexicon lemma orth and ordered consistently with NN output index + auto lemmas = lexicon_->lemmas(); + + /* + * Collect all possible extensions for all hypotheses in the beam. + */ + std::vector extensions; + extensions.reserve(beam_.size() * lexicon_->nLemmas()); + + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + + // Iterate over possible successors (all lemmas) + for (auto lemmaIt = lemmas.first; lemmaIt != lemmas.second; ++lemmaIt) { + const Bliss::Lemma* lemma(*lemmaIt); + Nn::LabelIndex tokenIdx = lemma->id(); + + extensions.push_back( + {tokenIdx, + lemma->pronunciations().first, + hyp.score, + 0, + inferTransitionType(hyp.currentToken, tokenIdx), + hypIndex}); + } + } + + /* + * Create scoring requests for the label scorer. + * Each extension candidate makes up a request. + */ + std::vector requests; + requests.reserve(extensions.size()); + for (const auto& extension : extensions) { + requests.push_back({beam_[extension.baseHypIndex].scoringContext, extension.nextToken, extension.transitionType}); + } + + /* + * Perform scoring of all the requests with the label scorer. + */ + scoringTime_.tic(); + auto result = labelScorer_->computeScoresWithTimes(requests); + scoringTime_.toc(); + + if (not result) { + // LabelScorer could not compute scores -> no search step can be made. + return false; + } + + for (size_t requestIdx = 0ul; requestIdx < extensions.size(); ++requestIdx) { + extensions[requestIdx].score += result->scores[requestIdx]; + extensions[requestIdx].timeframe = result->timeframes[requestIdx]; + } + + /* + * Prune set of possible extensions by max beam size and possibly also by score. 
+ */ + beamPruning(extensions); + if (debugLogging_) { + log() << extensions.size() << " candidates survived beam pruning"; + } + + std::sort(extensions.begin(), extensions.end()); + + if (useScorePruning_) { + // Extensions are sorted by score after `beamPruning`. + scorePruning(extensions); + + if (debugLogging_) { + log() << extensions.size() << " candidates survived score pruning"; + } + } + + /* + * Create new beam from surviving extensions. + */ + std::vector newBeam; + newBeam.reserve(extensions.size()); + + for (auto const& extension : extensions) { + auto const& baseHyp = beam_[extension.baseHypIndex]; + auto newScoringContext = labelScorer_->extendedScoringContext({baseHyp.scoringContext, extension.nextToken, extension.transitionType}); + newBeam.push_back({baseHyp, extension, newScoringContext}); + } + + /* + * For all hypotheses with the same scoring context keep only the best since they will + * all develop in the same way. + */ + recombination(newBeam); + if (debugLogging_) { + log() << newBeam.size() << " hypotheses after recombination"; + + std::stringstream ss; + for (size_t hypIdx = 0ul; hypIdx < newBeam.size(); ++hypIdx) { + ss << "Hypothesis " << hypIdx + 1ul << ": " << newBeam[hypIdx].toString() << "\n"; + } + log() << ss.str(); + } + + beam_.swap(newBeam); + + if (logStepwiseStatistics_) { + clog() << Core::XmlOpen("search-step-stats"); + clog() << Core::XmlOpen("active-hyps") << beam_.size() << Core::XmlClose("active-hyps"); + clog() << Core::XmlOpen("best-hyp-score") << beam_.front().score << Core::XmlClose("best-hyp-score"); + clog() << Core::XmlOpen("worst-hyp-score") << beam_.back().score << Core::XmlClose("worst-hyp-score"); + clog() << Core::XmlClose("search-step-stats"); + } + + return true; +} + +/* + * ======================= + * === LabelHypothesis === + * ======================= + */ + +LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() + : scoringContext(), + currentToken(Core::Type::max), + score(0.0), + 
traceback() {} + +LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( + LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base, + LexiconfreeTimesyncBeamSearch::ExtensionCandidate const& extension, + Nn::ScoringContextRef const& newScoringContext) + : scoringContext(newScoringContext), + currentToken(extension.nextToken), + score(extension.score), + traceback(base.traceback) { + switch (extension.transitionType) { + case Nn::LabelScorer::LABEL_TO_LABEL: + case Nn::LabelScorer::LABEL_TO_BLANK: + case Nn::LabelScorer::BLANK_TO_LABEL: + this->traceback.push_back(TracebackItem(extension.pron, extension.timeframe + 1u, ScoreVector(extension.score, 0.0), {})); + break; + case Nn::LabelScorer::LABEL_LOOP: + case Nn::LabelScorer::BLANK_LOOP: + if (not this->traceback.empty()) { + this->traceback.back().score.acoustic = extension.score; + this->traceback.back().time = extension.timeframe + 1u; + } + break; + } +} + +std::string LexiconfreeTimesyncBeamSearch::LabelHypothesis::toString() const { + std::stringstream ss; + ss << "Score: " << score << ", label sequence: "; + for (auto& item : traceback) { + if (item.pronunciation != nullptr) { + ss << item.pronunciation->lemma()->symbol() << " "; + } + } + return ss.str(); +} + +/* + * ======================= + * === TimeStatistic ===== + * ======================= + */ + +void LexiconfreeTimesyncBeamSearch::TimeStatistic::reset() { + total = 0.0; +} + +void LexiconfreeTimesyncBeamSearch::TimeStatistic::tic() { + startTime = std::chrono::steady_clock::now(); +} + +void LexiconfreeTimesyncBeamSearch::TimeStatistic::toc() { + auto endTime = std::chrono::steady_clock::now(); + // Duration in milliseconds + total += std::chrono::duration_cast>>(endTime - startTime).count(); +} + +double LexiconfreeTimesyncBeamSearch::TimeStatistic::getTotalMilliseconds() const { + return total; +} +} // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh 
b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh new file mode 100644 index 00000000..0eb79d1a --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -0,0 +1,170 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH +#define LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH + +#include +#include +#include +#include +#include +#include + +namespace Search { + +/* + * Simple time synchronous beam search algorithm without pronunciation lexicon, word-level LM or transition model. + * Can handle a blank symbol if a blank index is set. + * Main purpose is open vocabulary search with CTC/Neural Transducer (or similar) models. + * Supports global pruning by max beam-size and by score difference to the best hypothesis. + * Uses a LabelScorer to context initialization/extension and scoring. + * + * The search requires a lexicon that represents the vocabulary. Each lemma is viewed as a token with its index + * in the lexicon corresponding to the associated output index of the label scorer. 
+ */ +class LexiconfreeTimesyncBeamSearch : public SearchAlgorithmV2 { +protected: + /* + * Possible extension for some label hypothesis in the beam + */ + struct ExtensionCandidate { + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of lemma corresponding to `nextToken` for traceback + Score score; // Would-be score of full hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam + + bool operator<(ExtensionCandidate const& other) { + return score < other.score; + } + }; + + /* + * Struct containing all information about a single hypothesis in the beam + */ + struct LabelHypothesis { + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + Score score; // Full score of hypothesis + Traceback traceback; // Associated traceback to return + + LabelHypothesis(); + LabelHypothesis(LabelHypothesis const& base); + LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); + + /* + * Get string representation for debugging + */ + std::string toString() const; + }; + + /* + * Timer to add up computation times for sub-tasks performed repeatedly + * across the search. + */ + struct TimeStatistic { + public: + // Reset accumulated total to zero. 
+ void reset(); + + // Start timer + void tic(); + + // End running timer and add duration to total + void toc(); + + // Get total accumulated time in milliseconds + double getTotalMilliseconds() const; + + private: + std::chrono::time_point startTime; + double total; + }; + +public: + static const Core::ParameterInt paramMaxBeamSize; + static const Core::ParameterFloat paramScoreThreshold; + static const Core::ParameterInt paramBlankLabelIndex; + static const Core::ParameterBool paramAllowLabelLoop; + static const Core::ParameterBool paramUseSentenceEnd; + static const Core::ParameterBool paramSentenceEndIndex; + static const Core::ParameterBool paramLogStepwiseStatistics; + static const Core::ParameterBool paramDebugLogging; + + LexiconfreeTimesyncBeamSearch(Core::Configuration const&); + + // Inherited methods from `SearchAlgorithmV2` + + Speech::ModelCombination::Mode requiredModelCombination() const override; + bool setModelCombination(Speech::ModelCombination const& modelCombination) override; + void reset() override; + void enterSegment(Bliss::SpeechSegment const* = nullptr) override; + void finishSegment() override; + void putFeature(std::shared_ptr const& data, size_t featureSize) override; + void putFeature(std::vector const& data) override; + void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) override; + Core::Ref getCurrentBestTraceback() const override; + Core::Ref getCurrentBestWordLattice() const override; + bool decodeStep() override; + +private: + void resetStatistics(); + void logStatistics() const; + + Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const; + + /* + * Helper function for pruning to maxBeamSize_ + */ + void beamPruning(std::vector& extensions) const; + + /* + * Helper function for pruning to scoreThreshold_ + */ + void scorePruning(std::vector& extensions) const; + + /* + * Helper function for recombination of hypotheses with the same scoring 
context + */ + void recombination(std::vector& hypotheses); + + size_t maxBeamSize_; + + bool useScorePruning_; + Score scoreThreshold_; + + bool useBlank_; + Nn::LabelIndex blankLabelIndex_; + + bool allowLabelLoop_; + + bool logStepwiseStatistics_; + bool debugLogging_; + + Core::Ref labelScorer_; + Bliss::LexiconRef lexicon_; + std::vector beam_; + + TimeStatistic initializationTime_; + TimeStatistic featureProcessingTime_; + TimeStatistic scoringTime_; + TimeStatistic contextExtensionTime_; +}; + +} // namespace Search + +#endif // LEXICONFREE_BEAM_SEARCH_HH diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/Makefile b/src/Search/LexiconfreeTimesyncBeamSearch/Makefile new file mode 100644 index 00000000..c9834e9a --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/Makefile @@ -0,0 +1,24 @@ +#!gmake + +TOPDIR = ../../.. + +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintLexiconfreeTimesyncBeamSearch.$(a) + +LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O = $(OBJDIR)/LexiconfreeTimesyncBeamSearch.o + + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintLexiconfreeTimesyncBeamSearch.$(a): $(LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O) + $(MAKELIB) $@ $^ + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O:.o=.d) diff --git a/src/Search/Makefile b/src/Search/Makefile index 9328ef46..a8408fb3 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -33,6 +33,7 @@ LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskAStarSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif +SUBDIRS += LexiconfreeTimesyncBeamSearch ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst endif @@ -63,6 +64,9 @@ Wfst: AdvancedTreeSearch: $(MAKE) -C $@ libSprintAdvancedTreeSearch.$(a) +LexiconfreeTimesyncBeamSearch: + $(MAKE) 
-C $@ libSprintLexiconfreeTimesyncBeamSearch.$(a) + include $(TOPDIR)/Rules.make sinclude $(LIBSPRINTSEARCH_O:.o=.d) diff --git a/src/Search/Traceback.hh b/src/Search/Traceback.hh index b144f337..d685938d 100644 --- a/src/Search/Traceback.hh +++ b/src/Search/Traceback.hh @@ -16,6 +16,7 @@ #define TRACEBACK_HH #include +#include #include #include #include @@ -60,7 +61,7 @@ public: : pronunciation(p), time(t), score(s), transit(te) {} }; -class Traceback : public std::vector { +class Traceback : public std::vector, public Core::ReferenceCounted { public: void write(std::ostream& os, Core::Ref) const; Fsa::ConstAutomatonRef lemmaAcceptor(Core::Ref) const; From bf0a8ce30de44ac2140b3425b10ecaf29d18c5ab Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 19 Feb 2025 19:16:58 +0100 Subject: [PATCH 02/52] Add some comments --- .../LexiconfreeTimesyncBeamSearch.cc | 2 ++ .../LexiconfreeTimesyncBeamSearch.hh | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 80148ab7..aaa9994d 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -142,6 +142,7 @@ Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTracebac } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const { + // TODO: Currently this is just a linear lattice representing the best traceback. Create a proper lattice instead. 
if (beam_.front().traceback.empty()) { return Core::ref(new Lattice::WordLatticeAdaptor()); } @@ -439,4 +440,5 @@ void LexiconfreeTimesyncBeamSearch::TimeStatistic::toc() { double LexiconfreeTimesyncBeamSearch::TimeStatistic::getTotalMilliseconds() const { return total; } + } // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 0eb79d1a..ffff9d60 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -125,6 +125,10 @@ private: void resetStatistics(); void logStatistics() const; + /* + * Infer type of transition between two tokens based on whether each of them is blank + * and/or whether they are the same + */ Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const; /* From d6689b43a8987f077efe06be2f70b10ca3235f7d Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 20 Feb 2025 09:58:55 +0100 Subject: [PATCH 03/52] Add `createSearchAlgorithm` to Search::Module --- src/Search/Module.cc | 21 +++++++++++++++++++++ src/Search/Module.hh | 14 ++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 666ce408..74aa1c29 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -16,6 +16,7 @@ #include #include #include +#include "LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh" #ifdef MODULE_SEARCH_WFST #include #include @@ -32,6 +33,13 @@ using namespace Search; Module_::Module_() { } +const Core::Choice Module_::searchTypeV2Choice( + "lexiconfree-timesync-beam-search", SearchTypeV2::LexiconfreeTimesyncBeamSearchType, + Core::Choice::endMark()); + +const Core::ParameterChoice Module_::searchTypeV2Param( + "type", &Module_::searchTypeV2Choice, "type of search", 
SearchTypeV2::LexiconfreeTimesyncBeamSearchType); + SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configuration& config) const { SearchAlgorithm* recognizer = 0; switch (type) { @@ -68,6 +76,19 @@ SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configur return recognizer; } +SearchAlgorithmV2* Module_::createSearchAlgorithm(const Core::Configuration& config) const { + SearchAlgorithmV2* searchAlgorithm = 0; + switch (searchTypeV2Param(config)) { + case LexiconfreeTimesyncBeamSearchType: + searchAlgorithm = new Search::LexiconfreeTimesyncBeamSearch(config); + break; + default: + Core::Application::us()->criticalError("Unknown search algorithm type: %d", searchTypeV2Param(config)); + break; + } + return searchAlgorithm; +} + LatticeHandler* Module_::createLatticeHandler(const Core::Configuration& c) const { LatticeHandler* handler = new LatticeHandler(c); #ifdef MODULE_SEARCH_WFST diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 86be9122..40b40734 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -17,6 +17,7 @@ #include #include +#include "SearchV2.hh" namespace Search { @@ -30,12 +31,21 @@ enum SearchType { ExpandingFsaSearchType }; +enum SearchTypeV2 { + LexiconfreeTimesyncBeamSearchType +}; + class Module_ { +private: + static const Core::Choice searchTypeV2Choice; + static const Core::ParameterChoice searchTypeV2Param; + public: Module_(); - SearchAlgorithm* createRecognizer(SearchType type, const Core::Configuration& config) const; - LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; + SearchAlgorithm* createRecognizer(SearchType type, const Core::Configuration& config) const; + SearchAlgorithmV2* createSearchAlgorithm(const Core::Configuration& config) const; + LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; }; typedef Core::SingletonHolder Module; From 664945cd789d6d726224741be89b0f58c9af043b Mon Sep 17 00:00:00 2001 From: Simon Berger 
Date: Wed, 26 Feb 2025 16:24:01 +0100 Subject: [PATCH 04/52] Fix compilation --- .../LexiconfreeTimesyncBeamSearch.hh | 1 - src/Speech/Makefile | 1 + src/Test/Makefile | 1 + src/Tools/Archiver/Makefile | 1 + src/Tools/NnTrainer/Makefile | 1 + 5 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index ffff9d60..8e026a31 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -63,7 +63,6 @@ protected: Traceback traceback; // Associated traceback to return LabelHypothesis(); - LabelHypothesis(LabelHypothesis const& base); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); /* diff --git a/src/Speech/Makefile b/src/Speech/Makefile index 4b5ab352..45f03252 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -46,6 +46,7 @@ CHECK_O = $(OBJDIR)/check.o \ ../Mm/libSprintMm.$(a) \ ../Mc/libSprintMc.$(a) \ ../Search/libSprintSearch.$(a) \ + ../Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) \ ../Bliss/libSprintBliss.$(a) \ ../Flow/libSprintFlow.$(a) \ ../Fsa/libSprintFsa.$(a) \ diff --git a/src/Test/Makefile b/src/Test/Makefile index b9f3f2fc..52c80e5a 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -62,6 +62,7 @@ UNIT_TEST_O = $(OBJDIR)/UnitTester.o $(TEST_O) \ ../Core/libSprintCore.$(a)\ ../Speech/libSprintSpeech.$(a) \ ../Search/libSprintSearch.$(a) \ + ../Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) \ ../Lattice/libSprintLattice.$(a) \ ../Am/libSprintAm.$(a) \ ../Mm/libSprintMm.$(a) \ diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index e9000e78..6e1b762a 100644 --- a/src/Tools/Archiver/Makefile +++ 
b/src/Tools/Archiver/Makefile @@ -12,6 +12,7 @@ TARGETS = archiver$(exe) ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Speech/libSprintSpeech.$(a) \ ../../Search/libSprintSearch.$(a) \ + ../../Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) \ ../../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) \ ../../Lattice/libSprintLattice.$(a) \ ../../Lm/libSprintLm.$(a) \ diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index fc71a4bd..707285ce 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -25,6 +25,7 @@ NN_TRAINER_O = $(OBJDIR)/NnTrainer.o \ ../../Mm/libSprintMm.$(a) \ ../../Nn/libSprintNn.$(a) \ ../../Search/libSprintSearch.$(a) \ + ../../Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) \ ../../Signal/libSprintSignal.$(a) \ ../../Speech/libSprintSpeech.$(a) From 488fb0ee96e7a03d1e641f533fa4fa6c9e62b202 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 28 Feb 2025 11:56:04 +0100 Subject: [PATCH 05/52] Refactor traceback/lattice building and construct proper (nonlinear) lattice from beam --- src/Nn/LabelScorer/LabelScorer.hh | 2 + .../LexiconfreeTimesyncBeamSearch.cc | 101 ++++++++++-------- .../LexiconfreeTimesyncBeamSearch.hh | 11 +- src/Search/Traceback.cc | 101 ++++++++++++++++++ src/Search/Traceback.hh | 55 +++++++++- 5 files changed, 217 insertions(+), 53 deletions(-) diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index b732df8b..788edbb7 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -83,6 +83,8 @@ public: LABEL_TO_BLANK, BLANK_TO_LABEL, BLANK_LOOP, + INITIAL_LABEL, + INITIAL_BLANK, }; // Request for scoring or context extension diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index aaa9994d..af112700 100644 --- 
a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -20,6 +20,7 @@ #include #include #include +#include "Search/Traceback.hh" namespace Search { @@ -138,42 +139,15 @@ void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr con } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const { - return Core::ref(new Traceback(beam_.front().traceback)); + return beam_.front().trace->getTraceback(); } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const { - // TODO: Currently this is just a linear lattice representing the best traceback. Create a proper lattice instead. - if (beam_.front().traceback.empty()) { - return Core::ref(new Lattice::WordLatticeAdaptor()); + std::vector> traces; + for (auto const& hyp : beam_) { + traces.push_back(hyp.trace); } - - // use default LemmaAlphabet mode of StandardWordLattice - Core::Ref result(new Lattice::StandardWordLattice(lexicon_)); - Core::Ref wordBoundaries(new Lattice::WordBoundaries); - - // create a linear lattice from the traceback - Fsa::State* currentState = result->initialState(); - for (auto it = beam_.front().traceback.begin(); it != beam_.front().traceback.end(); ++it) { - wordBoundaries->set(currentState->id(), Lattice::WordBoundary(it->time)); - Fsa::State* nextState; - if (std::next(it) == beam_.front().traceback.end()) { - nextState = result->finalState(); - } - else { - nextState = result->newState(); - } - ScoreVector scores = it->score; - if (it != beam_.front().traceback.begin()) { - scores -= std::prev(it)->score; - } - result->newArc(currentState, nextState, it->pronunciation->lemma(), scores.acoustic, scores.lm); - currentState = nextState; - } - - result->setWordBoundaries(wordBoundaries); - result->addAcyclicProperty(); - - return Core::ref(new Lattice::WordLatticeAdaptor(result)); + return buildWordLatticeFromTraces(traces, lexicon_); } void 
LexiconfreeTimesyncBeamSearch::resetStatistics() { @@ -197,6 +171,15 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy bool prevIsBlank = (prevLabel == blankLabelIndex_); bool nextIsBlank = (nextLabel == blankLabelIndex_); + if (prevLabel == Core::Type::max) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::INITIAL_BLANK; + } + else { + return Nn::LabelScorer::TransitionType::INITIAL_LABEL; + } + } + if (prevIsBlank) { if (nextIsBlank) { return Nn::LabelScorer::TransitionType::BLANK_LOOP; @@ -244,16 +227,26 @@ void LexiconfreeTimesyncBeamSearch::scorePruning(std::vector& hypotheses) { - std::vector recombinedHypotheses; + std::vector newHypotheses; - std::unordered_set seenScoringContexts; + // Map each unique ScoringContext in newHypotheses to its hypothesis + std::unordered_map seenScoringContexts; for (auto const& hyp : hypotheses) { if (seenScoringContexts.find(hyp.scoringContext) == seenScoringContexts.end()) { - recombinedHypotheses.push_back(hyp); - seenScoringContexts.insert(hyp.scoringContext); + // Hyp ScoringContext is new so it just gets pushed in + newHypotheses.push_back(hyp); + seenScoringContexts.insert({hyp.scoringContext, &newHypotheses.back()}); + } + else { + // Hyp ScoringContext already exists on a better existing hypothesis, so + // it gets merged into the existing one by adding it as a Trace sibling + verify(not hyp.trace->sibling); + auto* existingHyp = seenScoringContexts[hyp.scoringContext]; + hyp.trace->sibling = existingHyp->trace->sibling; + existingHyp->trace->sibling = hyp.trace; } } - hypotheses.swap(recombinedHypotheses); + hypotheses.swap(newHypotheses); } bool LexiconfreeTimesyncBeamSearch::decodeStep() { @@ -380,7 +373,7 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Core::Type::max), score(0.0), - traceback() {} + trace() {} LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( 
LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base, @@ -389,28 +382,42 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), score(extension.score), - traceback(base.traceback) { + trace() { switch (extension.transitionType) { + case Nn::LabelScorer::INITIAL_BLANK: + case Nn::LabelScorer::INITIAL_LABEL: case Nn::LabelScorer::LABEL_TO_LABEL: case Nn::LabelScorer::LABEL_TO_BLANK: case Nn::LabelScorer::BLANK_TO_LABEL: - this->traceback.push_back(TracebackItem(extension.pron, extension.timeframe + 1u, ScoreVector(extension.score, 0.0), {})); + trace = Core::ref(new LatticeTrace( + base.trace, + extension.pron, + extension.timeframe + 1, + {extension.score, 0}, + {})); break; case Nn::LabelScorer::LABEL_LOOP: case Nn::LabelScorer::BLANK_LOOP: - if (not this->traceback.empty()) { - this->traceback.back().score.acoustic = extension.score; - this->traceback.back().time = extension.timeframe + 1u; - } + // `base.trace` is empty in the first step but at that point only `INITIAL_BLANK` and `INITIAL_LABEL` transitions can happen. + // Afterwards, `base.trace` should always be non-empty. 
+ verify(base.trace); + + // Copy base trace and update it + trace = Core::ref(new LatticeTrace(*base.trace)); + trace->score.acoustic = extension.score; + trace->time = extension.timeframe + 1; break; } } std::string LexiconfreeTimesyncBeamSearch::LabelHypothesis::toString() const { std::stringstream ss; - ss << "Score: " << score << ", label sequence: "; - for (auto& item : traceback) { - if (item.pronunciation != nullptr) { + ss << "Score: " << score << ", traceback: "; + + auto traceback = trace->getTraceback(); + + for (auto& item : *traceback) { + if (item.pronunciation and item.pronunciation->lemma()) { ss << item.pronunciation->lemma()->symbol() << " "; } } diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 8e026a31..a85b7a99 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace Search { @@ -57,16 +58,16 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - Score score; // Full score of hypothesis - Traceback traceback; // Associated traceback to return + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + Score score; // Full score of hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building off of hypothesis LabelHypothesis(); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, 
Nn::ScoringContextRef const& newScoringContext); /* - * Get string representation for debugging + * Get string representation for debugging. */ std::string toString() const; }; diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 4726bd3c..32efbdae 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -14,6 +14,9 @@ */ #include "Traceback.hh" +#include +#include +#include namespace Search { @@ -74,4 +77,102 @@ Lattice::WordLatticeRef Traceback::wordLattice(Core::Ref l result->setWordBoundaries(Core::ref(new Lattice::WordBoundaries)); return result; } + +LatticeTrace::LatticeTrace( + Core::Ref const& pre, + const Bliss::LemmaPronunciation* p, + Speech::TimeframeIndex t, + ScoreVector s, + Search::TracebackItem::Transit const& transit) + : TracebackItem(p, t, s, transit), predecessor(pre), sibling() {} + +Core::Ref LatticeTrace::getTraceback() const { + Core::Ref traceback; + + if (predecessor) { + traceback = predecessor->getTraceback(); + } + else { + traceback = Core::ref(new Traceback()); + traceback->push_back(TracebackItem(0, 0, {0, 0}, {})); + } + traceback->push_back(*this); + + return traceback; +} + +Core::Ref buildWordLatticeFromTraces(std::vector> const& traces, Core::Ref lexicon) { + // use default LemmaAlphabet mode of StandardWordLattice + Core::Ref result(new Lattice::StandardWordLattice(lexicon)); + Core::Ref wordBoundaries(new Lattice::WordBoundaries); + + // Map traces to lattice states + std::unordered_map stateMap; + + Fsa::State* initialState = result->initialState(); + wordBoundaries->set(initialState->id(), Lattice::WordBoundary(0)); + + Fsa::State* finalState = result->finalState(); + + // Stack for depth-first search through traces of all hypotheses in the beam + std::stack> traceStack; + + // Create a state for the current trace of each hypothesis in the beam, + // connect this state to the final state and add it to the stack + Speech::TimeframeIndex finalTime = 0; + for (auto const& trace : traces) { + auto* 
state = result->newState(); + stateMap[trace.get()] = state; + result->newArc(state, finalState, trace->pronunciation, 0, 0); // Score 0 for arc to final state + traceStack.push(trace); + finalTime = std::max(finalTime, trace->time + 1); + } + wordBoundaries->set(finalState->id(), Lattice::WordBoundary(finalTime)); + + // Perform depth-first search + while (not traceStack.empty()) { + auto const& trace = traceStack.top(); + traceStack.pop(); + + // A trace on the stack already has an associated state + Fsa::State* currentState = stateMap[trace.get()]; + wordBoundaries->set(currentState->id(), Lattice::WordBoundary(trace->time)); + + // Iterate through siblings of current trace + // All siblings share the same lattice state + for (auto arcTrace = trace; arcTrace; arcTrace = arcTrace->sibling) { + // For current sibling, get its predecessor, create a state for that predecessor + // and connect it to the current state. + auto& preTrace = arcTrace->predecessor; + Fsa::State* preState; + ScoreVector scores = trace->score; + if (not preTrace) { + // If trace has no predecessor, it gets connected to the initial state + preState = initialState; + } + else { + // If trace has a predecessor, get or create a state for it. Arc score + // is difference between trace scores from predecessor to current. 
+ scores -= preTrace->score; + if (stateMap.find(preTrace.get()) == stateMap.end()) { + preState = result->newState(); + stateMap[preTrace.get()] = preState; + traceStack.push(preTrace); + } + else { + preState = stateMap[preTrace.get()]; + } + } + + // Create arc from predecessor state to current state + result->newArc(preState, currentState, arcTrace->pronunciation, scores.acoustic, scores.lm); + } + } + + result->setWordBoundaries(wordBoundaries); + result->addAcyclicProperty(); + + return Core::ref(new Lattice::WordLatticeAdaptor(result)); +} + } // namespace Search diff --git a/src/Search/Traceback.hh b/src/Search/Traceback.hh index d685938d..a61a5205 100644 --- a/src/Search/Traceback.hh +++ b/src/Search/Traceback.hh @@ -23,6 +23,9 @@ namespace Search { +/* + * Struct to join AM and LM score and allow element-wise operations. + */ struct ScoreVector { Speech::Score acoustic, lm; ScoreVector(Speech::Score a, Speech::Score l) @@ -48,6 +51,9 @@ struct ScoreVector { } }; +/* + * Data associated with a single traceback node + */ struct TracebackItem { public: typedef Lattice::WordBoundary::Transit Transit; @@ -61,7 +67,11 @@ public: : pronunciation(p), time(t), score(s), transit(te) {} }; -class Traceback : public std::vector, public Core::ReferenceCounted { +/* + * Vector of TracebackItems together with some functions for conversions and IO + */ +class Traceback : public std::vector, + public Core::ReferenceCounted { public: void write(std::ostream& os, Core::Ref) const; Fsa::ConstAutomatonRef lemmaAcceptor(Core::Ref) const; @@ -69,6 +79,49 @@ public: Lattice::WordLatticeRef wordLattice(Core::Ref) const; }; +/* + * TracebackItem together with predecessor and sibling pointers. + * Used to build the lattice after recognition. + * Siblings are traces which will share the same lattice state (i.e. hypotheses with the same ScoringContext) but + * have different predecessors. 
+ * So a trace structure like this (where "<-" indicates a predecessor and "v" indicates a sibling) + * A <- B + * v + * A' <- B' + * will lead to a lattice like this: + * O - O + * / + * O + * + * Siblings form a chain so that the last sibling in the chain only has an empty Ref as its sibling. + * An empty Ref predecessor means that this trace will be connected to the initial lattice state. + */ +class LatticeTrace : public Core::ReferenceCounted, + public TracebackItem { +public: + Core::Ref predecessor; + Core::Ref sibling; + + LatticeTrace(Core::Ref const& pre, + Bliss::LemmaPronunciation const* p, + Speech::TimeframeIndex t, + ScoreVector s, + Transit const& transit); + + /* + * Perform best-predecessor traceback. + * Ordered by increasing timestep. + */ + Core::Ref getTraceback() const; +}; + +/* + * Build a word lattice from a set of traces. The given traces will be conencted to the final lattice + * state and they are traced back until an empty predecessor at which point they get connected to the + * initial lattice state. 
+ */ +Core::Ref buildWordLatticeFromTraces(std::vector> const& traces, Core::Ref lexicon); + } // namespace Search #endif // TRACEBACK_HH From 1599302bf8cd21b7a449ab5ea75a8cd6a58297e1 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 28 Feb 2025 12:06:13 +0100 Subject: [PATCH 06/52] Factor out time statistics into new Core::StopWatch class --- src/Core/Makefile | 1 + src/Core/StopWatch.cc | 66 +++++++++++++++++++ src/Core/StopWatch.hh | 54 +++++++++++++++ src/Search/AdvancedTreeSearch/Helpers.hh | 35 ++-------- .../LexiconfreeTimesyncBeamSearch.cc | 60 +++++------------ .../LexiconfreeTimesyncBeamSearch.hh | 33 ++-------- 6 files changed, 150 insertions(+), 99 deletions(-) create mode 100644 src/Core/StopWatch.cc create mode 100644 src/Core/StopWatch.hh diff --git a/src/Core/Makefile b/src/Core/Makefile index 7dde3b56..75e1e30f 100644 --- a/src/Core/Makefile +++ b/src/Core/Makefile @@ -43,6 +43,7 @@ LIBSPRINTCORE_O = $(OBJDIR)/Application.o \ $(OBJDIR)/ReferenceCounting.o \ $(OBJDIR)/ResourceUsageInfo.o \ $(OBJDIR)/Statistics.o \ + $(OBJDIR)/StopWatch.o \ $(OBJDIR)/StringExpression.o \ $(OBJDIR)/StringUtilities.o \ $(OBJDIR)/TextStream.o \ diff --git a/src/Core/StopWatch.cc b/src/Core/StopWatch.cc new file mode 100644 index 00000000..843141e5 --- /dev/null +++ b/src/Core/StopWatch.cc @@ -0,0 +1,66 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "StopWatch.hh" + +namespace Core { + +StopWatch::StopWatch() + : running_(false), startTime_(), elapsedTime_(0.0) {} + +void StopWatch::start() { + if (running_) { + return; + } + + startTime_ = std::chrono::steady_clock::now(); + running_ = true; +} + +void StopWatch::stop() { + if (not running_) { + return; + } + auto endTime = std::chrono::steady_clock::now(); + elapsedTime_ += std::chrono::duration_cast(endTime - startTime_).count(); + running_ = false; +} + +void StopWatch::reset() { + elapsedTime_ = 0; + running_ = false; +} + +double StopWatch::elapsedSeconds() const { + return elapsedTime_ / 1e9; +} + +double StopWatch::elapsedCentiseconds() const { + return elapsedTime_ / 1e7; +} + +double StopWatch::elapsedMilliseconds() const { + return elapsedTime_ / 1e6; +} + +double StopWatch::elapsedMicroseconds() const { + return elapsedTime_ / 1e3; +} + +double StopWatch::elapsedNanoseconds() const { + return elapsedTime_; +} + +} // namespace Core diff --git a/src/Core/StopWatch.hh b/src/Core/StopWatch.hh new file mode 100644 index 00000000..7b58e3d0 --- /dev/null +++ b/src/Core/StopWatch.hh @@ -0,0 +1,54 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef STOPWATCH_HH +#define STOPWATCH_HH + +#include + +namespace Core { + +/* + * Timer to add up computation times for sub-tasks performed repeatedly + * across the search. 
+ */ +struct StopWatch { +public: + StopWatch(); + + // Reset accumulated total to zero. + void reset(); + + // Start timer + void start(); + + // End running timer and add duration to total + void stop(); + + // Getter functions to get the total elapsed time in different units + double elapsedSeconds() const; + double elapsedCentiseconds() const; + double elapsedMilliseconds() const; + double elapsedMicroseconds() const; + double elapsedNanoseconds() const; + +private: + bool running_; + std::chrono::steady_clock::time_point startTime_; + double elapsedTime_; +}; + +} // namespace Core + +#endif // TIMER_HH diff --git a/src/Search/AdvancedTreeSearch/Helpers.hh b/src/Search/AdvancedTreeSearch/Helpers.hh index 293d4bb5..ebc35e20 100644 --- a/src/Search/AdvancedTreeSearch/Helpers.hh +++ b/src/Search/AdvancedTreeSearch/Helpers.hh @@ -16,6 +16,7 @@ #define HELPERS_HH #include +#include #include #include #include @@ -36,12 +37,10 @@ class Configuration; bool isBackwardRecognition(const Core::Configuration& config); -class PerformanceCounter { +class PerformanceCounter : public Core::StopWatch { public: PerformanceCounter(Search::SearchSpaceStatistics& stats, const std::string& name, bool start = true) - : running_(false), - totalTime_(0), - timeStats(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { + : Core::StopWatch(), timeStats(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { if (start) this->start(); } @@ -50,32 +49,13 @@ public: stopAndYield(); } - void start() { - stop(); - - running_ = true; - TIMER_START(starttime_); - } - - void stop() { - if (running_) { - running_ = false; - - double diff = 0; // in secs - timeval end; - - TIMER_STOP(starttime_, end, diff); - totalTime_ += diff * 100; // centi secs - } - } - /// Prints the current instruction count to the statistics object void stopAndYield(bool print = false) { stop(); - timeStats += totalTime_; + timeStats += elapsedCentiseconds(); if (print) - std::cout << " time: " << 
totalTime_ << std::endl; - totalTime_ = 0; + std::cout << " time: " << elapsedCentiseconds() << std::endl; + reset(); } static inline u64 instructions() { @@ -90,9 +70,6 @@ public: } private: - bool running_; - timeval starttime_; - f32 totalTime_; Core::Statistics& timeStats; }; diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index af112700..1895452a 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -81,7 +81,7 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration } void LexiconfreeTimesyncBeamSearch::reset() { - initializationTime_.tic(); + initializationTime_.start(); labelScorer_->reset(); @@ -90,7 +90,7 @@ void LexiconfreeTimesyncBeamSearch::reset() { beam_.push_back(LabelHypothesis()); beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); - initializationTime_.toc(); + initializationTime_.stop(); } Speech::ModelCombination::Mode LexiconfreeTimesyncBeamSearch::requiredModelCombination() const { @@ -106,36 +106,36 @@ bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination } void LexiconfreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) { - initializationTime_.tic(); + initializationTime_.start(); labelScorer_->reset(); resetStatistics(); - initializationTime_.toc(); + initializationTime_.stop(); } void LexiconfreeTimesyncBeamSearch::finishSegment() { - featureProcessingTime_.tic(); + featureProcessingTime_.start(); labelScorer_->signalNoMoreFeatures(); - featureProcessingTime_.toc(); + featureProcessingTime_.stop(); decodeManySteps(); logStatistics(); } void LexiconfreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { - featureProcessingTime_.tic(); + featureProcessingTime_.start(); 
labelScorer_->addInput(data, featureSize); - featureProcessingTime_.toc(); + featureProcessingTime_.stop(); } void LexiconfreeTimesyncBeamSearch::putFeature(std::vector const& data) { - featureProcessingTime_.tic(); + featureProcessingTime_.start(); labelScorer_->addInput(data); - featureProcessingTime_.toc(); + featureProcessingTime_.stop(); } void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { - featureProcessingTime_.tic(); + featureProcessingTime_.start(); labelScorer_->addInputs(data, timeSize, featureSize); - featureProcessingTime_.toc(); + featureProcessingTime_.stop(); } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const { @@ -159,10 +159,10 @@ void LexiconfreeTimesyncBeamSearch::resetStatistics() { void LexiconfreeTimesyncBeamSearch::logStatistics() const { clog() << Core::XmlOpen("timing-statistics") + Core::XmlAttribute("unit", "milliseconds"); - clog() << Core::XmlOpen("initialization-time") << initializationTime_.getTotalMilliseconds() << Core::XmlClose("initialization-time"); - clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.getTotalMilliseconds() << Core::XmlClose("feature-processing-time"); - clog() << Core::XmlOpen("scoring-time") << scoringTime_.getTotalMilliseconds() << Core::XmlClose("scoring-time"); - clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.getTotalMilliseconds() << Core::XmlClose("context-extension-time"); + clog() << Core::XmlOpen("initialization-time") << initializationTime_.elapsedMilliseconds() << Core::XmlClose("initialization-time"); + clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.elapsedMilliseconds() << Core::XmlClose("feature-processing-time"); + clog() << Core::XmlOpen("scoring-time") << scoringTime_.elapsedMilliseconds() << Core::XmlClose("scoring-time"); + clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << 
Core::XmlClose("context-extension-time"); clog() << Core::XmlClose("timing-statistics"); } @@ -290,9 +290,9 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { /* * Perform scoring of all the requests with the label scorer. */ - scoringTime_.tic(); + scoringTime_.start(); auto result = labelScorer_->computeScoresWithTimes(requests); - scoringTime_.toc(); + scoringTime_.stop(); if (not result) { // LabelScorer could not compute scores -> no search step can be made. @@ -424,28 +424,4 @@ std::string LexiconfreeTimesyncBeamSearch::LabelHypothesis::toString() const { return ss.str(); } -/* - * ======================= - * === TimeStatistic ===== - * ======================= - */ - -void LexiconfreeTimesyncBeamSearch::TimeStatistic::reset() { - total = 0.0; -} - -void LexiconfreeTimesyncBeamSearch::TimeStatistic::tic() { - startTime = std::chrono::steady_clock::now(); -} - -void LexiconfreeTimesyncBeamSearch::TimeStatistic::toc() { - auto endTime = std::chrono::steady_clock::now(); - // Duration in milliseconds - total += std::chrono::duration_cast>>(endTime - startTime).count(); -} - -double LexiconfreeTimesyncBeamSearch::TimeStatistic::getTotalMilliseconds() const { - return total; -} - } // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index a85b7a99..52d8c7e1 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -18,11 +18,11 @@ #include #include +#include #include #include #include #include -#include namespace Search { @@ -72,29 +72,6 @@ protected: std::string toString() const; }; - /* - * Timer to add up computation times for sub-tasks performed repeatedly - * across the search. - */ - struct TimeStatistic { - public: - // Reset accumulated total to zero. 
- void reset(); - - // Start timer - void tic(); - - // End running timer and add duration to total - void toc(); - - // Get total accumulated time in milliseconds - double getTotalMilliseconds() const; - - private: - std::chrono::time_point startTime; - double total; - }; - public: static const Core::ParameterInt paramMaxBeamSize; static const Core::ParameterFloat paramScoreThreshold; @@ -163,10 +140,10 @@ private: Bliss::LexiconRef lexicon_; std::vector beam_; - TimeStatistic initializationTime_; - TimeStatistic featureProcessingTime_; - TimeStatistic scoringTime_; - TimeStatistic contextExtensionTime_; + Core::StopWatch initializationTime_; + Core::StopWatch featureProcessingTime_; + Core::StopWatch scoringTime_; + Core::StopWatch contextExtensionTime_; }; } // namespace Search From 9a609169753f20d3b65550f71887e90398a66dbb Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 28 Feb 2025 13:00:25 +0100 Subject: [PATCH 07/52] Don't copy sibling from predecessor --- .../LexiconfreeTimesyncBeamSearch.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 1895452a..d5a062da 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -404,6 +404,7 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( // Copy base trace and update it trace = Core::ref(new LatticeTrace(*base.trace)); + trace->sibling = {}; trace->score.acoustic = extension.score; trace->time = extension.timeframe + 1; break; From 8e964234fb5dc3d465f4cf9174d8da39c33b9f42 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 28 Feb 2025 13:20:13 +0100 Subject: [PATCH 08/52] Better handling of blank index --- .../LexiconfreeTimesyncBeamSearch.cc | 24 +++++++++++++++---- .../LexiconfreeTimesyncBeamSearch.hh | 2 +- 2 files changed, 
20 insertions(+), 6 deletions(-) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index d5a062da..8bf0943d 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -42,7 +42,7 @@ const Core::ParameterFloat LexiconfreeTimesyncBeamSearch::paramScoreThreshold( const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex( "blank-label-index", - "Index of the blank label in the lexicon. If not set, the search will not use blank.", + "Index of the blank label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.", Core::Type::max); const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramAllowLabelLoop( @@ -76,7 +76,10 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration scoringTime_(), contextExtensionTime_() { beam_.reserve(maxBeamSize_); - useBlank_ = blankLabelIndex_ != Core::Type::max; + useBlank_ = blankLabelIndex_ != Core::Type::max; + if (useBlank_) { + log() << "Use blank label with index " << blankLabelIndex_; + } useScorePruning_ = scoreThreshold_ != Core::Type::max; } @@ -101,6 +104,18 @@ bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination lexicon_ = modelCombination.lexicon(); labelScorer_ = modelCombination.labelScorer(); + auto blankLemma = lexicon_->specialLemma("blank"); + if (blankLemma) { + if (blankLabelIndex_ == Core::Type::max) { + blankLabelIndex_ = blankLemma->id(); + useBlank_ = true; + log() << "Use blank index " << blankLabelIndex_ << " inferred from lexicon"; + } + else if (blankLabelIndex_ != blankLemma->id()) { + warning() << "Blank lemma exists in lexicon with id " << blankLemma->id() << " but is overwritten by config parameter with value " << 
blankLabelIndex_; + } + } + reset(); return true; } @@ -167,9 +182,8 @@ void LexiconfreeTimesyncBeamSearch::logStatistics() const { } Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { - // These checks will result in false if `blankLabelIndex_` is still `Core::Type::max`, i.e. no blank is used - bool prevIsBlank = (prevLabel == blankLabelIndex_); - bool nextIsBlank = (nextLabel == blankLabelIndex_); + bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_); + bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); if (prevLabel == Core::Type::max) { if (nextIsBlank) { diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 52d8c7e1..6d2dd2f0 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -148,4 +148,4 @@ private: } // namespace Search -#endif // LEXICONFREE_BEAM_SEARCH_HH +#endif // LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH From 536ac826eafbe94fd412dd8bd1ef147aa563c85c Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 28 Feb 2025 13:28:27 +0100 Subject: [PATCH 09/52] Apply suggestions from code review --- .../LexiconfreeTimesyncBeamSearch.cc | 11 ++++++++--- .../LexiconfreeTimesyncBeamSearch.hh | 2 ++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 8bf0943d..0e621bea 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -220,7 +220,7 @@ void LexiconfreeTimesyncBeamSearch::beamPruning(std::vectorextendedScoringContext({baseHyp.scoringContext, 
extension.nextToken, extension.transitionType}); + auto const& baseHyp = beam_[extension.baseHypIndex]; + + auto newScoringContext = labelScorer_->extendedScoringContext( + {baseHyp.scoringContext, + extension.nextToken, + extension.transitionType}); + newBeam.push_back({baseHyp, extension, newScoringContext}); } diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 6d2dd2f0..71f19317 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -115,11 +115,13 @@ private: /* * Helper function for pruning to scoreThreshold_ + * Requires that the input extensions are already sorted by score */ void scorePruning(std::vector& extensions) const; /* * Helper function for recombination of hypotheses with the same scoring context + * Requires that the input hypotheses are already sorted by score */ void recombination(std::vector& hypotheses); From f21935e39111cfbc17648ceb826dd18b94d3af9b Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 13:51:28 +0100 Subject: [PATCH 10/52] Implement StopWatch class --- src/Core/Makefile | 1 + src/Core/StopWatch.cc | 70 +++++++++++++++++++++++++++++++++++++++++++ src/Core/StopWatch.hh | 63 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+) create mode 100644 src/Core/StopWatch.cc create mode 100644 src/Core/StopWatch.hh diff --git a/src/Core/Makefile b/src/Core/Makefile index 7dde3b56..75e1e30f 100644 --- a/src/Core/Makefile +++ b/src/Core/Makefile @@ -43,6 +43,7 @@ LIBSPRINTCORE_O = $(OBJDIR)/Application.o \ $(OBJDIR)/ReferenceCounting.o \ $(OBJDIR)/ResourceUsageInfo.o \ $(OBJDIR)/Statistics.o \ + $(OBJDIR)/StopWatch.o \ $(OBJDIR)/StringExpression.o \ $(OBJDIR)/StringUtilities.o \ $(OBJDIR)/TextStream.o \ diff --git a/src/Core/StopWatch.cc b/src/Core/StopWatch.cc new file mode 
100644 index 00000000..3ed2e872 --- /dev/null +++ b/src/Core/StopWatch.cc @@ -0,0 +1,70 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "StopWatch.hh" + +namespace Core { + +StopWatch::StopWatch() + : running_(false), startTime_(), elapsedTime_(0.0) {} + +void StopWatch::start() { + if (running_) { + return; + } + + startTime_ = std::chrono::steady_clock::now(); + running_ = true; +} + +void StopWatch::stop() { + if (not running_) { + return; + } + auto endTime = std::chrono::steady_clock::now(); + elapsedTime_ += std::chrono::duration_cast(endTime - startTime_).count(); + running_ = false; +} + +void StopWatch::reset() { + elapsedTime_ = 0; + running_ = false; +} + +double StopWatch::elapsedSeconds() const { + return elapsedNanoseconds() / 1e9; +} + +double StopWatch::elapsedCentiseconds() const { + return elapsedNanoseconds() / 1e7; +} + +double StopWatch::elapsedMilliseconds() const { + return elapsedNanoseconds() / 1e6; +} + +double StopWatch::elapsedMicroseconds() const { + return elapsedNanoseconds() / 1e3; +} + +double StopWatch::elapsedNanoseconds() const { + if (running_) { + auto currentTime = std::chrono::steady_clock::now(); + return elapsedTime_ + std::chrono::duration_cast(currentTime - startTime_).count(); + } + return elapsedTime_; +} + +} // namespace Core diff --git a/src/Core/StopWatch.hh b/src/Core/StopWatch.hh new file mode 100644 index 
00000000..5851f55a --- /dev/null +++ b/src/Core/StopWatch.hh @@ -0,0 +1,63 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef STOPWATCH_HH +#define STOPWATCH_HH + +#include + +namespace Core { + +/* + * Simple timer class with start/stop functions that accumulates all the timed intervals + * to a total. + */ +struct StopWatch { +public: + StopWatch(); + + /* + * Stops timer if it is running and resets accumulated time to zero. + */ + void reset(); + + /* + * Start timer. Does nothing if timer is already running. + */ + void start(); + + /* + * End running timer and add duration to total. Does nothing if timer is not running. + */ + void stop(); + + /* + * Getter functions to get the total elapsed time in different units. Includes the current interval + * if the timer is running. 
+ */ + double elapsedSeconds() const; + double elapsedCentiseconds() const; + double elapsedMilliseconds() const; + double elapsedMicroseconds() const; + double elapsedNanoseconds() const; + +private: + bool running_; + std::chrono::steady_clock::time_point startTime_; + double elapsedTime_; +}; + +} // namespace Core + +#endif // TIMER_HH From 5f82460b2ed00c0869aab6911da423b4ea810ae5 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 14:13:59 +0100 Subject: [PATCH 11/52] Use TIMER_START and TIMER_STOP macros instead --- src/Core/StopWatch.cc | 48 ++++++++++++++++++++++++------------------- src/Core/StopWatch.hh | 18 ++++++++-------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/Core/StopWatch.cc b/src/Core/StopWatch.cc index 3ed2e872..91979e9f 100644 --- a/src/Core/StopWatch.cc +++ b/src/Core/StopWatch.cc @@ -14,57 +14,63 @@ */ #include "StopWatch.hh" +#include "Utility.hh" namespace Core { StopWatch::StopWatch() - : running_(false), startTime_(), elapsedTime_(0.0) {} + : running_(false), startTime_(), elapsedSeconds_(0.0) {} void StopWatch::start() { if (running_) { return; } - startTime_ = std::chrono::steady_clock::now(); - running_ = true; + TIMER_START(startTime_); + running_ = true; } void StopWatch::stop() { if (not running_) { return; } - auto endTime = std::chrono::steady_clock::now(); - elapsedTime_ += std::chrono::duration_cast(endTime - startTime_).count(); + + timeval endTime; + double diff = 0; // in seconds + TIMER_STOP(startTime_, endTime, diff); + + elapsedSeconds_ += diff; running_ = false; } void StopWatch::reset() { - elapsedTime_ = 0; - running_ = false; + elapsedSeconds_ = 0; + running_ = false; } -double StopWatch::elapsedSeconds() const { - return elapsedNanoseconds() / 1e9; +double StopWatch::elapsedSeconds() { + if (running_) { + timeval endTime; + double diff = 0; // in seconds + TIMER_STOP(startTime_, endTime, diff); + } + return elapsedSeconds_; } -double StopWatch::elapsedCentiseconds() const { - 
return elapsedNanoseconds() / 1e7; +double StopWatch::elapsedCentiseconds() { + return elapsedSeconds() * 1e2; } -double StopWatch::elapsedMilliseconds() const { - return elapsedNanoseconds() / 1e6; +double StopWatch::elapsedMilliseconds() { + return elapsedSeconds() * 1e3; } -double StopWatch::elapsedMicroseconds() const { - return elapsedNanoseconds() / 1e3; +double StopWatch::elapsedMicroseconds() { + return elapsedSeconds() * 1e6; } -double StopWatch::elapsedNanoseconds() const { - if (running_) { - auto currentTime = std::chrono::steady_clock::now(); - return elapsedTime_ + std::chrono::duration_cast(currentTime - startTime_).count(); - } - return elapsedTime_; +double StopWatch::elapsedNanoseconds() { + return elapsedSeconds() * 1e9; } } // namespace Core diff --git a/src/Core/StopWatch.hh b/src/Core/StopWatch.hh index 5851f55a..a0ed87cb 100644 --- a/src/Core/StopWatch.hh +++ b/src/Core/StopWatch.hh @@ -15,7 +15,7 @@ #ifndef STOPWATCH_HH #define STOPWATCH_HH -#include +#include namespace Core { @@ -46,16 +46,16 @@ public: * Getter functions to get the total elapsed time in different units. Includes the current interval * if the timer is running. 
*/ - double elapsedSeconds() const; - double elapsedCentiseconds() const; - double elapsedMilliseconds() const; - double elapsedMicroseconds() const; - double elapsedNanoseconds() const; + double elapsedSeconds(); + double elapsedCentiseconds(); + double elapsedMilliseconds(); + double elapsedMicroseconds(); + double elapsedNanoseconds(); private: - bool running_; - std::chrono::steady_clock::time_point startTime_; - double elapsedTime_; + bool running_; + timeval startTime_; + double elapsedSeconds_; }; } // namespace Core From 4779dd5558a4f0ee42ae691fb82436fd30f07aeb Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 14:14:39 +0100 Subject: [PATCH 12/52] Simplify AdvancedTreeSearch PerformanceCounter by inheriting from StopWatch --- src/Search/AdvancedTreeSearch/Helpers.hh | 54 +++++------------------- 1 file changed, 10 insertions(+), 44 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/Helpers.hh b/src/Search/AdvancedTreeSearch/Helpers.hh index 293d4bb5..2c01db83 100644 --- a/src/Search/AdvancedTreeSearch/Helpers.hh +++ b/src/Search/AdvancedTreeSearch/Helpers.hh @@ -16,10 +16,9 @@ #define HELPERS_HH #include +#include #include -#include #include -#include #include "SearchSpaceStatistics.hh" namespace Search { @@ -36,63 +35,30 @@ class Configuration; bool isBackwardRecognition(const Core::Configuration& config); -class PerformanceCounter { +class PerformanceCounter : public Core::StopWatch { public: PerformanceCounter(Search::SearchSpaceStatistics& stats, const std::string& name, bool start = true) - : running_(false), - totalTime_(0), - timeStats(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { - if (start) + : Core::StopWatch(), timeStats(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { + if (start) { this->start(); + } } ~PerformanceCounter() { stopAndYield(); } - void start() { - stop(); - - running_ = true; - TIMER_START(starttime_); - } - - void stop() { - if (running_) { - running_ = false; - - 
double diff = 0; // in secs - timeval end; - - TIMER_STOP(starttime_, end, diff); - totalTime_ += diff * 100; // centi secs - } - } - /// Prints the current instruction count to the statistics object void stopAndYield(bool print = false) { stop(); - timeStats += totalTime_; - if (print) - std::cout << " time: " << totalTime_ << std::endl; - totalTime_ = 0; - } - - static inline u64 instructions() { - unsigned int a, d; - asm __volatile__("" - : - : - : "memory"); - asm volatile("rdtsc" - : "=a"(a), "=d"(d)); - return ((u64)a) | (((u64)d) << 32); + timeStats += elapsedCentiseconds(); + if (print) { + std::cout << " time: " << elapsedCentiseconds() << std::endl; + } + reset(); } private: - bool running_; - timeval starttime_; - f32 totalTime_; Core::Statistics& timeStats; }; From f5a318261732c82ba99798dc6f862e3f1c1fc496 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 14:27:53 +0100 Subject: [PATCH 13/52] Small fixes in StopWatch class --- src/Core/StopWatch.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Core/StopWatch.cc b/src/Core/StopWatch.cc index 91979e9f..6e02cdae 100644 --- a/src/Core/StopWatch.cc +++ b/src/Core/StopWatch.cc @@ -36,10 +36,8 @@ void StopWatch::stop() { } timeval endTime; - double diff = 0; // in seconds - TIMER_STOP(startTime_, endTime, diff); + TIMER_STOP(startTime_, endTime, elapsedSeconds_); - elapsedSeconds_ += diff; running_ = false; } @@ -51,8 +49,12 @@ void StopWatch::reset() { double StopWatch::elapsedSeconds() { if (running_) { timeval endTime; - double diff = 0; // in seconds - TIMER_STOP(startTime_, endTime, diff); + double currentTime = 0; // in seconds + + // Note: This macro doesn't actually "stop" anything, it just writes into `endTime` and `currentTime` + TIMER_STOP(startTime_, endTime, currentTime); + + return elapsedSeconds_ + currentTime; } return elapsedSeconds_; } From 97e5bd7dad3c99588497aa0b9dc05574c0371800 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 
2025 14:33:11 +0100 Subject: [PATCH 14/52] Make StopWatch a member of PerformanceCounter instead of inheriting --- src/Search/AdvancedTreeSearch/Helpers.hh | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/Helpers.hh b/src/Search/AdvancedTreeSearch/Helpers.hh index 2c01db83..6bba739f 100644 --- a/src/Search/AdvancedTreeSearch/Helpers.hh +++ b/src/Search/AdvancedTreeSearch/Helpers.hh @@ -35,12 +35,12 @@ class Configuration; bool isBackwardRecognition(const Core::Configuration& config); -class PerformanceCounter : public Core::StopWatch { +class PerformanceCounter { public: PerformanceCounter(Search::SearchSpaceStatistics& stats, const std::string& name, bool start = true) - : Core::StopWatch(), timeStats(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { + : stopWatch_(), timeStats_(stats.customStatistics("Profiling: " + name + ": Centiseconds")) { if (start) { - this->start(); + stopWatch_.start(); } } @@ -48,18 +48,28 @@ public: stopAndYield(); } + void start() { + stopWatch_.stop(); + stopWatch_.start(); + } + + void stop() { + stopWatch_.stop(); + } + /// Prints the current instruction count to the statistics object void stopAndYield(bool print = false) { stop(); - timeStats += elapsedCentiseconds(); + timeStats_ += stopWatch_.elapsedCentiseconds(); if (print) { - std::cout << " time: " << elapsedCentiseconds() << std::endl; + std::cout << " time: " << stopWatch_.elapsedCentiseconds() << std::endl; } - reset(); + stopWatch_.reset(); } private: - Core::Statistics& timeStats; + Core::StopWatch stopWatch_; + Core::Statistics& timeStats_; }; inline f32 scaledLogAdd(f32 a, f32 b, f32 scale, f32 invertedScale) { From b77cf23c274cfed144b40d0dda5ff87c5c1d67d7 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 17:24:17 +0100 Subject: [PATCH 15/52] Implement LatticeTrace class --- src/Search/Traceback.cc | 113 ++++++++++++++++++++++++++++++++++++++++ 
src/Search/Traceback.hh | 78 ++++++++++++++++++++++++++- 2 files changed, 190 insertions(+), 1 deletion(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 4726bd3c..10afcae4 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -14,6 +14,9 @@ */ #include "Traceback.hh" +#include +#include +#include namespace Search { @@ -74,4 +77,114 @@ Lattice::WordLatticeRef Traceback::wordLattice(Core::Ref l result->setWordBoundaries(Core::ref(new Lattice::WordBoundaries)); return result; } + +LatticeTrace::LatticeTrace( + Core::Ref const& pre, + const Bliss::LemmaPronunciation* p, + Speech::TimeframeIndex t, + ScoreVector s, + Search::TracebackItem::Transit const& transit) + : TracebackItem(p, t, s, transit), predecessor(pre), sibling() {} + +Core::Ref LatticeTrace::getPredecessor() const { + return predecessor_.get(); +} + +Core::Ref LatticeTrace::getSibling() const { + return sibling_.get(); +} + +void LatticeTrace::appendSiblingToChain(Core::Ref sibling) { + if (sibling_) { + sibling_.appendSiblingToChain(sibling); + } + else { + sibling_ = sibling; + } +} + +Core::Ref LatticeTrace::getTraceback() const { + Core::Ref traceback; + + if (predecessor) { + traceback = predecessor->getTraceback(); + } + else { + traceback = Core::ref(new Traceback()); + traceback->push_back(TracebackItem(0, 0, {0, 0}, {})); + } + traceback->push_back(*this); + + return traceback; +} + +Core::Ref LatticeTrace::buildWordLattice(Core::Ref lexicon) { + // use default LemmaAlphabet mode of StandardWordLattice + Core::Ref result(new Lattice::StandardWordLattice(lexicon)); + Core::Ref wordBoundaries(new Lattice::WordBoundaries); + + // Map traces to lattice states + std::unordered_map stateMap; + + // Create an initial State at time 0 which represents empty predecessors + Fsa::State* initialState = result->initialState(); + wordBoundaries->set(initialState->id(), Lattice::WordBoundary(0)); + + // Stack for depth-first search through traces of all hypotheses in the 
beam + std::stack traceStack; + + // Create a final state which represents this trace itself + Fsa::State* finalState = result->finalState(); + stateMap[this] = finalState; + traceStack.push(this); + wordBoundaries->set(finalState->id(), Lattice::WordBoundary(this->time)); + + // Perform depth-first search + while (not traceStack.empty()) { + auto* trace = traceStack.top(); + traceStack.pop(); + + // A trace on the stack already has an associated state + Fsa::State* currentState = stateMap[trace]; + wordBoundaries->set(currentState->id(), Lattice::WordBoundary(trace->time)); + + // Iterate through siblings of current trace + // All siblings share the same lattice state + for (auto arcTrace = trace; arcTrace != nullptr; arcTrace = arcTrace->getSibling().get()) { + // For current sibling, get its predecessor, create a state for that predecessor + // and connect it to the current state. + auto* preTrace = arcTrace->getPredecessor().get(); + Fsa::State* preState; + ScoreVector scores = trace->score; + if (preTrace == nullptr) { + // If trace has no predecessor, it gets connected to the initial state + preState = initialState; + } + else { + // If trace has a predecessor, get or create a state for it. Arc score + // is difference between trace scores from predecessor to current. 
+ scores -= preTrace->score; + if (stateMap.find(preTrace) == stateMap.end()) { + preState = result->newState(); + stateMap[preTrace] = preState; + traceStack.push(preTrace); + } + else { + preState = stateMap[preTrace.get()]; + } + } + + // Create arc from predecessor state to current state + result->newArc(preState, currentState, arcTrace->pronunciation, scores.acoustic, scores.lm); + } + } + + result->setWordBoundaries(wordBoundaries); + result->addAcyclicProperty(); + + return Core::ref(new Lattice::WordLatticeAdaptor(result)); +} + +LatticeTrace:: + } // namespace Search diff --git a/src/Search/Traceback.hh b/src/Search/Traceback.hh index b144f337..e6e0fb46 100644 --- a/src/Search/Traceback.hh +++ b/src/Search/Traceback.hh @@ -16,12 +16,16 @@ #define TRACEBACK_HH #include +#include #include #include #include namespace Search { +/* + * Struct to join AM and LM score and allow element-wise operations. + */ struct ScoreVector { Speech::Score acoustic, lm; ScoreVector(Speech::Score a, Speech::Score l) @@ -47,6 +51,9 @@ struct ScoreVector { } }; +/* + * Data associated with a single traceback node + */ struct TracebackItem { public: typedef Lattice::WordBoundary::Transit Transit; @@ -60,7 +67,11 @@ public: : pronunciation(p), time(t), score(s), transit(te) {} }; -class Traceback : public std::vector { +/* + * Vector of TracebackItems together with some functions for conversions and IO + */ +class Traceback : public std::vector, + public Core::ReferenceCounted { public: void write(std::ostream& os, Core::Ref) const; Fsa::ConstAutomatonRef lemmaAcceptor(Core::Ref) const; @@ -68,6 +79,71 @@ public: Lattice::WordLatticeRef wordLattice(Core::Ref) const; }; +/* + * TracebackItem together with predecessor and sibling pointers. + * Used to build the lattice after recognition. + * Siblings are traces which will share the same lattice state (i.e. hypotheses with the same ScoringContext) but + * have different predecessors. 
+ * So a trace structure like this (where "<-" indicates a predecessor and "v" indicates a sibling) + * A <- B + * v + * A' <- B' + * will lead to a lattice like this: + * O - O + * / + * O + * + * Siblings form a chain so that the last sibling in the chain only has an empty Ref as its sibling. + * An empty Ref predecessor means that this trace will be connected to the initial lattice state. + * + * Note: Don't connect traces as siblings or predecessor of each other in a circular way as this may result + * in infinite loops during traversal. + */ +class LatticeTrace : public Core::ReferenceCounted, + public TracebackItem { +public: + LatticeTrace(Core::Ref const& pre, + Bliss::LemmaPronunciation const* p, + Speech::TimeframeIndex t, + ScoreVector s, + Transit const& transit); + + /* + * Getter functions + */ + LatticeTrace* getPredecessor() const; + LatticeTrace* getSibling() const; + + /* + * Append sibling chain to the end of the own sibling chain + * Example: If we have sibling chains + * + * A -> B -> C and D -> E + * + * then after A.appendSibling(D) it will be + * + * A -> B -> C -> D -> E + */ + void appendSiblingToChain(Core::Ref sibling); + + /* + * Perform best-predecessor traceback. + * Ordered by increasing timestep. + */ + Core::Ref performTraceback() const; + + /* + * Build a word lattice from a traces. The given trace will be represent the final lattice + * state and it is traced back along predecessors and siblings until ending up at the empty predecessor which + * is represented as the initial lattice state. 
+ */ + Core::Ref buildWordLattice(Core::Ref lexicon) const; + +private: + Core::Ref predecessor_; + Core::Ref sibling_; +}; + } // namespace Search #endif // TRACEBACK_HH From 5fcfff74b71d76dacdb26b8a38ea9d4a8bd400b1 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 17:33:08 +0100 Subject: [PATCH 16/52] Make predecessor and sibling public members --- src/Search/Traceback.cc | 20 +++++--------------- src/Search/Traceback.hh | 13 +++---------- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 10afcae4..a6c62f6f 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -86,14 +86,6 @@ LatticeTrace::LatticeTrace( Search::TracebackItem::Transit const& transit) : TracebackItem(p, t, s, transit), predecessor(pre), sibling() {} -Core::Ref LatticeTrace::getPredecessor() const { - return predecessor_.get(); -} - -Core::Ref LatticeTrace::getSibling() const { - return sibling_.get(); -} - void LatticeTrace::appendSiblingToChain(Core::Ref sibling) { if (sibling_) { sibling_.appendSiblingToChain(sibling); @@ -140,22 +132,22 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Refset(finalState->id(), Lattice::WordBoundary(this->time)); // Perform depth-first search + Fsa::State *preState, currentState; while (not traceStack.empty()) { auto* trace = traceStack.top(); traceStack.pop(); // A trace on the stack already has an associated state - Fsa::State* currentState = stateMap[trace]; + currentState = stateMap[trace]; wordBoundaries->set(currentState->id(), Lattice::WordBoundary(trace->time)); // Iterate through siblings of current trace // All siblings share the same lattice state - for (auto arcTrace = trace; arcTrace != nullptr; arcTrace = arcTrace->getSibling().get()) { + for (auto arcTrace = trace; arcTrace; arcTrace = arcTrace->sibling) { // For current sibling, get its predecessor, create a state for that predecessor // and connect it to the current state. 
- auto* preTrace = arcTrace->getPredecessor().get(); - Fsa::State* preState; - ScoreVector scores = trace->score; + auto* preTrace = arcTrace->predecessor.get(); + ScoreVector scores = trace->score; if (preTrace == nullptr) { // If trace has no predecessor, it gets connected to the initial state preState = initialState; @@ -185,6 +177,4 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Ref predecessor; + Core::Ref sibling; + LatticeTrace(Core::Ref const& pre, Bliss::LemmaPronunciation const* p, Speech::TimeframeIndex t, ScoreVector s, Transit const& transit); - /* - * Getter functions - */ - LatticeTrace* getPredecessor() const; - LatticeTrace* getSibling() const; - /* * Append sibling chain to the end of the own sibling chain * Example: If we have sibling chains @@ -138,10 +135,6 @@ public: * is represented as the initial lattice state. */ Core::Ref buildWordLattice(Core::Ref lexicon) const; - -private: - Core::Ref predecessor_; - Core::Ref sibling_; }; } // namespace Search From 3152300c3da29d31d5d9d98b88a25817f9acaaf4 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 20:41:52 +0100 Subject: [PATCH 17/52] Look for initial trace instead of associating empty trace with initial state --- src/Search/Traceback.cc | 94 +++++++++++++++++++++++++++-------------- src/Search/Traceback.hh | 35 ++++++++++++--- 2 files changed, 91 insertions(+), 38 deletions(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index a6c62f6f..7e3c44bf 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -79,27 +79,30 @@ Lattice::WordLatticeRef Traceback::wordLattice(Core::Ref l } LatticeTrace::LatticeTrace( - Core::Ref const& pre, - const Bliss::LemmaPronunciation* p, - Speech::TimeframeIndex t, - ScoreVector s, + Core::Ref const& predecessor, + const Bliss::LemmaPronunciation* pronunciation, + Speech::TimeframeIndex timeframe, + ScoreVector scores, Search::TracebackItem::Transit const& transit) - : TracebackItem(p, t, s, transit), 
predecessor(pre), sibling() {} + : TracebackItem(pronunciation, timeframe, scores, transit), predecessor(predecessor), sibling() {} -void LatticeTrace::appendSiblingToChain(Core::Ref sibling) { - if (sibling_) { - sibling_.appendSiblingToChain(sibling); +LatticeTrace::LatticeTrace(Speech::TimeframeIndex timeframe, ScoreVector scores, const Transit& transit) + : TracebackItem(0, timeframe, scores, transit), predecessor(), sibling() {} + +void LatticeTrace::appendSiblingToChain(Core::Ref newSibling) { + if (sibling) { + sibling->appendSiblingToChain(newSibling); } else { - sibling_ = sibling; + sibling = newSibling; } } -Core::Ref LatticeTrace::getTraceback() const { +Core::Ref LatticeTrace::performTraceback() const { Core::Ref traceback; if (predecessor) { - traceback = predecessor->getTraceback(); + traceback = predecessor->performTraceback(); } else { traceback = Core::ref(new Traceback()); @@ -110,7 +113,7 @@ Core::Ref LatticeTrace::getTraceback() const { return traceback; } -Core::Ref LatticeTrace::buildWordLattice(Core::Ref lexicon) { +Core::Ref LatticeTrace::buildWordLattice(Core::Ref lexicon) const { // use default LemmaAlphabet mode of StandardWordLattice Core::Ref result(new Lattice::StandardWordLattice(lexicon)); Core::Ref wordBoundaries(new Lattice::WordBoundaries); @@ -119,8 +122,8 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Ref stateMap; // Create an initial State at time 0 which represents empty predecessors - Fsa::State* initialState = result->initialState(); - wordBoundaries->set(initialState->id(), Lattice::WordBoundary(0)); + Fsa::State* initialState = result->initialState(); + Core::Ref initialTrace; // Stack for depth-first search through traces of all hypotheses in the beam std::stack traceStack; @@ -129,47 +132,51 @@ Core::Ref LatticeTrace::buildWordLattice(Core::ReffinalState(); stateMap[this] = finalState; traceStack.push(this); - wordBoundaries->set(finalState->id(), Lattice::WordBoundary(this->time)); // Perform depth-first search - 
Fsa::State *preState, currentState; + Fsa::State* preState; + Fsa::State* currentState; while (not traceStack.empty()) { - auto* trace = traceStack.top(); + const auto* trace = traceStack.top(); traceStack.pop(); // A trace on the stack already has an associated state currentState = stateMap[trace]; - wordBoundaries->set(currentState->id(), Lattice::WordBoundary(trace->time)); + wordBoundaries->set(currentState->id(), Lattice::WordBoundary(trace->time, trace->transit)); // Iterate through siblings of current trace // All siblings share the same lattice state - for (auto arcTrace = trace; arcTrace; arcTrace = arcTrace->sibling) { + for (auto arcTrace = trace; arcTrace; arcTrace = arcTrace->sibling.get()) { // For current sibling, get its predecessor, create a state for that predecessor // and connect it to the current state. - auto* preTrace = arcTrace->predecessor.get(); - ScoreVector scores = trace->score; - if (preTrace == nullptr) { - // If trace has no predecessor, it gets connected to the initial state - preState = initialState; - } - else { + auto const preTrace = arcTrace->predecessor; + verify(preTrace); + + if (preTrace->predecessor) { // If trace has a predecessor, get or create a state for it. Arc score // is difference between trace scores from predecessor to current. 
- scores -= preTrace->score; - if (stateMap.find(preTrace) == stateMap.end()) { - preState = result->newState(); - stateMap[preTrace] = preState; - traceStack.push(preTrace); + if (stateMap.find(preTrace.get()) == stateMap.end()) { + preState = result->newState(); + stateMap[preTrace.get()] = preState; + traceStack.push(preTrace.get()); } else { preState = stateMap[preTrace.get()]; } } + else { + // If trace has no predecessor, it gets connected to the initial state + preState = initialState; + initialTrace = preTrace; + } // Create arc from predecessor state to current state + ScoreVector scores = trace->score - preTrace->score; result->newArc(preState, currentState, arcTrace->pronunciation, scores.acoustic, scores.lm); } } + verify(initialTrace); + wordBoundaries->set(initialState->id(), Lattice::WordBoundary(initialTrace->time, initialTrace->transit)); result->setWordBoundaries(wordBoundaries); result->addAcyclicProperty(); @@ -177,4 +184,29 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Ref phi) const { + performTraceback()->write(os, phi); +} + +void LatticeTrace::getLemmaSequence(std::vector& lemmaSequence) const { + if (predecessor) { + predecessor->getLemmaSequence(lemmaSequence); + } + if (pronunciation) { + lemmaSequence.push_back(const_cast(pronunciation->lemma())); + } +} + +u32 LatticeTrace::wordCount() const { + u32 count = 0; + if (pronunciation) { + ++count; + } + if (predecessor) { + count += predecessor->wordCount(); + } + + return count; +} + } // namespace Search diff --git a/src/Search/Traceback.hh b/src/Search/Traceback.hh index fc648c88..07084cbc 100644 --- a/src/Search/Traceback.hh +++ b/src/Search/Traceback.hh @@ -105,12 +105,14 @@ public: Core::Ref predecessor; Core::Ref sibling; - LatticeTrace(Core::Ref const& pre, - Bliss::LemmaPronunciation const* p, - Speech::TimeframeIndex t, - ScoreVector s, + LatticeTrace(Core::Ref const& predecessor, + Bliss::LemmaPronunciation const* pronunciation, + Speech::TimeframeIndex timeframe, + 
ScoreVector scores, Transit const& transit); + LatticeTrace(Speech::TimeframeIndex timeframe, ScoreVector scores, const Transit& transit); + /* * Append sibling chain to the end of the own sibling chain * Example: If we have sibling chains @@ -121,7 +123,7 @@ public: * * A -> B -> C -> D -> E */ - void appendSiblingToChain(Core::Ref sibling); + void appendSiblingToChain(Core::Ref newSibling); /* * Perform best-predecessor traceback. @@ -131,10 +133,29 @@ public: /* * Build a word lattice from a traces. The given trace will be represent the final lattice - * state and it is traced back along predecessors and siblings until ending up at the empty predecessor which - * is represented as the initial lattice state. + * state and it is traced back along predecessors and siblings until ending up at a trace with empty predecessor + * which represents the initial state. + * + * This assumes that all paths lead back to a single initial trace with empty predecessor. + * It's also assumed that this trace itself does have a predecessor, i.e. initial and final + * state in the lattice are different. */ Core::Ref buildWordLattice(Core::Ref lexicon) const; + + /* + * Write valid pronunciations of associated traceback to output stream. + */ + void write(std::ostream& os, Core::Ref phi) const; + + /* + * Collect lemmas of valid pronunciations of associated traceback into `lemmaSequence`. + */ + void getLemmaSequence(std::vector& lemmaSequence) const; + + /* + * Count number of items with valid pronunciations along associated traceback. 
+ */ + u32 wordCount() const; }; } // namespace Search From 0b676f9acbaf186e186c88c7a53b96ae8b9b1c30 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 4 Mar 2025 21:20:59 +0100 Subject: [PATCH 18/52] Remove redundant includes --- src/Search/Traceback.cc | 2 -- src/Search/Traceback.hh | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 7e3c44bf..641a6bcf 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -14,8 +14,6 @@ */ #include "Traceback.hh" -#include -#include #include namespace Search { diff --git a/src/Search/Traceback.hh b/src/Search/Traceback.hh index 07084cbc..59ea89f2 100644 --- a/src/Search/Traceback.hh +++ b/src/Search/Traceback.hh @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include namespace Search { From 159fbd8078a0927c2adc63776984b90826627f2c Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 5 Mar 2025 01:48:11 +0100 Subject: [PATCH 19/52] Add assertions for assumptions in lattice building --- src/Search/Traceback.cc | 6 ++++++ src/Search/Traceback.hh | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 641a6bcf..a5c139a7 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -14,6 +14,7 @@ */ #include "Traceback.hh" + #include namespace Search { @@ -112,6 +113,9 @@ Core::Ref LatticeTrace::performTraceback() const { } Core::Ref LatticeTrace::buildWordLattice(Core::Ref lexicon) const { + // If predecessor Ref is empty the lattice would only have one state + require(predecessor); + // use default LemmaAlphabet mode of StandardWordLattice Core::Ref result(new Lattice::StandardWordLattice(lexicon)); Core::Ref wordBoundaries(new Lattice::WordBoundaries); @@ -164,6 +168,8 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Ref buildWordLattice(Core::Ref lexicon) const; From 0577e79e32a2401a3509c779e2666ac1b3527629 Mon Sep 17 00:00:00 2001 From: 
Simon Berger Date: Wed, 5 Mar 2025 14:25:08 +0100 Subject: [PATCH 20/52] Remove wrong assertion --- src/Search/Traceback.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index a5c139a7..4f559787 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -169,7 +169,6 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Ref Date: Wed, 5 Mar 2025 14:59:57 +0100 Subject: [PATCH 21/52] Remove initial item in `performTraceback` --- src/Search/Traceback.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 4f559787..48043d4c 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -105,7 +105,6 @@ Core::Ref LatticeTrace::performTraceback() const { } else { traceback = Core::ref(new Traceback()); - traceback->push_back(TracebackItem(0, 0, {0, 0}, {})); } traceback->push_back(*this); From d393c7e04d30eae507e03381b1b6ba723ae7ea67 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 5 Mar 2025 15:06:16 +0100 Subject: [PATCH 22/52] Fix arc scores --- src/Search/Traceback.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index 48043d4c..87e5e9b6 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -173,7 +173,7 @@ Core::Ref LatticeTrace::buildWordLattice(Core::Refscore - preTrace->score; + ScoreVector scores = arcTrace->score - preTrace->score; result->newArc(preState, currentState, arcTrace->pronunciation, scores.acoustic, scores.lm); } } From d67cf45caf71e335d3e85acc008d4f04b7ec6f9c Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 5 Mar 2025 20:10:45 +0100 Subject: [PATCH 23/52] Update traceback/lattice building logic --- .../LexiconfreeTimesyncBeamSearch.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc 
b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 0e621bea..901e31e5 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -154,15 +154,19 @@ void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr con } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const { - return beam_.front().trace->getTraceback(); + return beam_.front().trace->performTraceback(); } Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const { - std::vector> traces; - for (auto const& hyp : beam_) { - traces.push_back(hyp.trace); + LatticeTrace endTrace(beam_.front().trace, 0, beam_.front().trace->time + 1, beam_.front().trace->score, {}); + + for (size_t hypIdx = 1ul; hypIdx < beam_.size(); ++hypIdx) { + auto& hyp = beam_[hypIdx]; + auto siblingTrace = Core::ref(new LatticeTrace(hyp.trace, 0, hyp.trace->time, hyp.trace->score, {})); + endTrace.appendSiblingToChain(siblingTrace); } - return buildWordLatticeFromTraces(traces, lexicon_); + + return endTrace.buildWordLattice(lexicon_); } void LexiconfreeTimesyncBeamSearch::resetStatistics() { @@ -392,7 +396,7 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Core::Type::max), score(0.0), - trace() {} + trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base, @@ -417,10 +421,6 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( break; case Nn::LabelScorer::LABEL_LOOP: case Nn::LabelScorer::BLANK_LOOP: - // `base.trace` is empty in the first step but at that point only `INITIAL_BLANK` and `INITIAL_LABEL` transitions can happen. - // Afterwards, `base.trace` should always be non-empty. 
- verify(base.trace); - // Copy base trace and update it trace = Core::ref(new LatticeTrace(*base.trace)); trace->sibling = {}; @@ -434,7 +434,7 @@ std::string LexiconfreeTimesyncBeamSearch::LabelHypothesis::toString() const { std::stringstream ss; ss << "Score: " << score << ", traceback: "; - auto traceback = trace->getTraceback(); + auto traceback = trace->performTraceback(); for (auto& item : *traceback) { if (item.pronunciation and item.pronunciation->lemma()) { From 54535e627f024e1b44840012f891b01f6ab2711a Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 5 Mar 2025 20:40:03 +0100 Subject: [PATCH 24/52] Make `elapsed` functions const --- src/Core/StopWatch.cc | 12 ++++++------ src/Core/StopWatch.hh | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Core/StopWatch.cc b/src/Core/StopWatch.cc index 6e02cdae..09b655cf 100644 --- a/src/Core/StopWatch.cc +++ b/src/Core/StopWatch.cc @@ -46,32 +46,32 @@ void StopWatch::reset() { running_ = false; } -double StopWatch::elapsedSeconds() { +double StopWatch::elapsedSeconds() const { if (running_) { timeval endTime; double currentTime = 0; // in seconds // Note: This macro doesn't actually "stop" anything, it just writes into `endTime` and `currentTime` - TIMER_STOP(startTime_, endTime, currentTime); + TIMER_STOP(const_cast(startTime_), endTime, currentTime); return elapsedSeconds_ + currentTime; } return elapsedSeconds_; } -double StopWatch::elapsedCentiseconds() { +double StopWatch::elapsedCentiseconds() const { return elapsedSeconds() * 1e2; } -double StopWatch::elapsedMilliseconds() { +double StopWatch::elapsedMilliseconds() const { return elapsedSeconds() * 1e3; } -double StopWatch::elapsedMicroseconds() { +double StopWatch::elapsedMicroseconds() const { return elapsedSeconds() * 1e6; } -double StopWatch::elapsedNanoseconds() { +double StopWatch::elapsedNanoseconds() const { return elapsedSeconds() * 1e9; } diff --git a/src/Core/StopWatch.hh b/src/Core/StopWatch.hh index 
a0ed87cb..856363c4 100644 --- a/src/Core/StopWatch.hh +++ b/src/Core/StopWatch.hh @@ -46,11 +46,11 @@ public: * Getter functions to get the total elapsed time in different units. Includes the current interval * if the timer is running. */ - double elapsedSeconds(); - double elapsedCentiseconds(); - double elapsedMilliseconds(); - double elapsedMicroseconds(); - double elapsedNanoseconds(); + double elapsedSeconds() const; + double elapsedCentiseconds() const; + double elapsedMilliseconds() const; + double elapsedMicroseconds() const; + double elapsedNanoseconds() const; private: bool running_; From c86855e0942f1e499d09130fd698f13762a995c7 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 5 Mar 2025 20:50:44 +0100 Subject: [PATCH 25/52] Add RecognizerNodeV2 --- src/Flf/Makefile | 1 + src/Flf/NodeRegistration.hh | 18 +++ src/Flf/RecognizerV2.cc | 231 +++++++++++++++++++++++++++++++++ src/Flf/RecognizerV2.hh | 74 +++++++++++ src/Speech/ModelCombination.cc | 54 +++++--- src/Speech/ModelCombination.hh | 14 +- 6 files changed, 369 insertions(+), 23 deletions(-) create mode 100644 src/Flf/RecognizerV2.cc create mode 100644 src/Flf/RecognizerV2.hh diff --git a/src/Flf/Makefile b/src/Flf/Makefile index 99afeb53..62db8d3a 100644 --- a/src/Flf/Makefile +++ b/src/Flf/Makefile @@ -59,6 +59,7 @@ LIBSPRINTFLF_O = \ $(OBJDIR)/Prune.o \ $(OBJDIR)/PushForwardRescoring.o \ $(OBJDIR)/Recognizer.o \ + $(OBJDIR)/RecognizerV2.o \ $(OBJDIR)/IncrementalRecognizer.o \ $(OBJDIR)/Rescore.o \ $(OBJDIR)/RescoreLm.o \ diff --git a/src/Flf/NodeRegistration.hh b/src/Flf/NodeRegistration.hh index 8beb4e47..2b9d03c8 100644 --- a/src/Flf/NodeRegistration.hh +++ b/src/Flf/NodeRegistration.hh @@ -51,6 +51,7 @@ #include "Prune.hh" #include "PushForwardRescoring.hh" #include "Recognizer.hh" +#include "RecognizerV2.hh" #include "Rescale.hh" #include "Rescore.hh" #include "RescoreLm.hh" @@ -2145,6 +2146,23 @@ void registerNodeCreators(NodeFactory* factory) { " 0:lattice", &createRecognizerNode)); + 
factory->add( + NodeCreator( + "recognizer-v2", + "Second version of RASR recognizer.\n" + "Output are lattices in Flf format.\n" + "Much more minimalistic than the first recognizer node\n" + "and works with a `SearchAlgorithmV2` instead of\n" + "`SearchAlgorithm`. Performs recognition of the input segments\n" + "and sends the result lattices as outputs.\n" + "[*.network.recognizer-v2]\n" + "type = recognizer-v2\n" + "input:\n" + " 0:bliss-speech-segment\n" + "output:\n" + " 0:lattice", + &createRecognizerNodeV2)); + factory->add( NodeCreator( "incremental-recognizer", diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc new file mode 100644 index 00000000..f8a26a02 --- /dev/null +++ b/src/Flf/RecognizerV2.cc @@ -0,0 +1,231 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RecognizerV2.hh" +#include +#include +#include "Core/XmlStream.hh" +#include "LatticeHandler.hh" +#include "Module.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(const std::string& name, const Core::Configuration& config) { + return NodeRef(new RecognizerNodeV2(name, config)); +} + +RecognizerNodeV2::RecognizerNodeV2(const std::string& name, const Core::Configuration& config) + : Node(name, config), + searchAlgorithm_(Search::Module::instance().createSearchAlgorithm(select("search-algorithm"))), + modelCombination_(config) { + Core::Configuration featureExtractionConfig(config, "feature-extraction"); + DataSourceRef dataSource = DataSourceRef(Speech::Module::instance().createDataSource(featureExtractionConfig)); + featureExtractor_ = SegmentwiseFeatureExtractorRef(new SegmentwiseFeatureExtractor(featureExtractionConfig, dataSource)); +} + +void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { + if (!segment->orth().empty()) { + clog() << Core::XmlOpen("orth") + Core::XmlAttribute("source", "reference") + << segment->orth() + << Core::XmlClose("orth"); + } + + // Initialize recognizer and feature extractor + searchAlgorithm_->reset(); + searchAlgorithm_->enterSegment(); + + featureExtractor_->enterSegment(segment); + DataSourceRef dataSource = featureExtractor_->extractor(); + dataSource->initialize(const_cast(segment)); + FeatureRef feature; + dataSource->getData(feature); + Time startTime = feature->timestamp().startTime(); + Time endTime; + + auto timerStart = std::chrono::steady_clock::now(); + + // Loop over features and perform recognition + do { + searchAlgorithm_->putFeature(*feature->mainStream()); + endTime = feature->timestamp().endTime(); + } while (dataSource->getData(feature)); + + searchAlgorithm_->finishSegment(); + searchAlgorithm_->decodeManySteps(); + dataSource->finalize(); + featureExtractor_->leaveSegment(segment); + + // Result processing and logging + auto traceback = 
searchAlgorithm_->getCurrentBestTraceback(); + + auto lattice = buildLattice(searchAlgorithm_->getCurrentBestWordLattice(), segment->name()); + resultBuffer_ = std::make_pair(lattice, SegmentRef(new Flf::Segment(segment))); + + Core::XmlWriter& os(clog()); + os << Core::XmlOpen("traceback"); + traceback->write(os, modelCombination_.lexicon()->phonemeInventory()); + os << Core::XmlClose("traceback"); + + os << Core::XmlOpen("orth") + Core::XmlAttribute("source", "recognized"); + for (auto const& tracebackItem : *traceback) { + if (tracebackItem.pronunciation and tracebackItem.pronunciation->lemma()) { + os << tracebackItem.pronunciation->lemma()->preferredOrthographicForm() << Core::XmlBlank(); + } + } + os << Core::XmlClose("orth"); + + auto timerEnd = std::chrono::steady_clock::now(); + double duration = std::chrono::duration(timerEnd - timerStart).count(); + double signalDuration = (endTime - startTime) * 1000.; // convert duration to ms + + clog() << Core::XmlOpen("flf-recognizer-time") + Core::XmlAttribute("unit", "milliseconds") << duration << Core::XmlClose("flf-recognizer-time"); + clog() << Core::XmlOpen("flf-recognizer-rtf") << (duration / signalDuration) << Core::XmlClose("flf-recognizer-rtf"); +} + +void RecognizerNodeV2::work() { + clog() << Core::XmlOpen("layer") + Core::XmlAttribute("name", name); + recognizeSegment(static_cast(requestData(0))); + clog() << Core::XmlClose("layer"); +} + +ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref latticeAdaptor, std::string segmentName) { + auto semiring = Semiring::create(Fsa::SemiringTypeTropical, 2); + semiring->setKey(0, "am"); + semiring->setScale(0, 1.0); + semiring->setKey(1, "lm"); + semiring->setScale(1, modelCombination_.languageModel()->scale()); + + auto sentenceEndLabel = Fsa::Epsilon; + const Bliss::Lemma* specialSentenceEndLemma = modelCombination_.lexicon()->specialLemma("sentence-end"); + if (specialSentenceEndLemma and specialSentenceEndLemma->nPronunciations() > 0) { + sentenceEndLabel 
= specialSentenceEndLemma->pronunciations().first->id(); + } + + Flf::LatticeHandler* handler = Flf::Module::instance().createLatticeHandler(config); + handler->setLexicon(Lexicon::us()); + if (latticeAdaptor->empty()) { + return ConstLatticeRef(); + } + ::Lattice::ConstWordLatticeRef lattice = latticeAdaptor->wordLattice(handler); + Core::Ref boundaries = lattice->wordBoundaries(); + Fsa::ConstAutomatonRef amFsa = lattice->part(::Lattice::WordLattice::acousticFsa); + Fsa::ConstAutomatonRef lmFsa = lattice->part(::Lattice::WordLattice::lmFsa); + require_(Fsa::isAcyclic(amFsa) && Fsa::isAcyclic(lmFsa)); + + StaticBoundariesRef b = StaticBoundariesRef(new StaticBoundaries); + StaticLatticeRef s = StaticLatticeRef(new StaticLattice); + s->setType(Fsa::TypeAcceptor); + s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); + s->setInputAlphabet(modelCombination_.lexicon()->lemmaPronunciationAlphabet()); + s->setSemiring(semiring); + s->setDescription(Core::form("recog(%s)", segmentName.c_str())); + s->setBoundaries(ConstBoundariesRef(b)); + s->setInitialStateId(0); + + Time timeOffset = (*boundaries)[amFsa->initialStateId()].time(); + + Fsa::Stack stateStack; + Core::Vector sidMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); + sidMap[amFsa->initialStateId()] = 0; + stateStack.push_back(amFsa->initialStateId()); + Fsa::StateId nextSid = 2; + Time finalTime = 0; + while (not stateStack.isEmpty()) { + Fsa::StateId sid = stateStack.pop(); + verify(sid < sidMap.size()); + const ::Lattice::WordBoundary& boundary((*boundaries)[sid]); + Fsa::ConstStateRef amSr = amFsa->getState(sid); + Fsa::ConstStateRef lmSr = lmFsa->getState(sid); + State* sp = new State(sidMap[sid]); + s->setState(sp); + b->set(sp->id(), Boundary(boundary.time() - timeOffset, + Boundary::Transit(boundary.transit().final, boundary.transit().initial))); + if (amSr->isFinal()) { + auto scores = semiring->create(); + scores->set(0, amSr->weight()); + scores->set(1, 
static_cast(lmSr->weight()) / semiring->scale(1)); + sp->newArc(1, scores, sentenceEndLabel); + finalTime = std::max(finalTime, boundary.time() - timeOffset); + } + for (Fsa::State::const_iterator am_a = amSr->begin(), lm_a = lmSr->begin(); (am_a != amSr->end()) && (lm_a != lmSr->end()); ++am_a, ++lm_a) { + sidMap.grow(am_a->target(), Fsa::InvalidStateId); + if (sidMap[am_a->target()] == Fsa::InvalidStateId) { + sidMap[am_a->target()] = nextSid++; + stateStack.push(am_a->target()); + } + Fsa::ConstStateRef targetAmSr = amFsa->getState(am_a->target()); + Fsa::ConstStateRef targetLmSr = amFsa->getState(lm_a->target()); + if (targetAmSr->isFinal() && targetLmSr->isFinal()) { + if (am_a->input() == Fsa::Epsilon) { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); + scores->add(0, Score(targetAmSr->weight())); + scores->add(1, Score(targetLmSr->weight()) / semiring->scale(1)); + sp->newArc(1, scores, sentenceEndLabel); + } + else { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); + sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + } + } + else { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); + sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + } + } + } + State* sp = new State(1); + sp->setFinal(semiring->clone(semiring->one())); + s->setState(sp); + b->set(sp->id(), Boundary(finalTime)); + return s; +} + +void RecognizerNodeV2::init(std::vector const& arguments) { + modelCombination_.build(searchAlgorithm_->requiredModelCombination(), searchAlgorithm_->requiredAcousticModel(), Lexicon::us()); + searchAlgorithm_->setModelCombination(modelCombination_); + if (not connected(0)) { + criticalError("Speech segment at port 1 required"); + } +} + +void RecognizerNodeV2::sync() { + 
resultBuffer_.first.reset(); + resultBuffer_.second.reset(); +} + +void RecognizerNodeV2::finalize() { + searchAlgorithm_->reset(); +} + +ConstSegmentRef RecognizerNodeV2::sendSegment(RecognizerNodeV2::Port to) { + if (!resultBuffer_.second) { + work(); + } + return resultBuffer_.second; +} + +ConstLatticeRef RecognizerNodeV2::sendLattice(RecognizerNodeV2::Port to) { + if (!resultBuffer_.first) { + work(); + } + return resultBuffer_.first; +} + +} // namespace Flf diff --git a/src/Flf/RecognizerV2.hh b/src/Flf/RecognizerV2.hh new file mode 100644 index 00000000..ea6b04e2 --- /dev/null +++ b/src/Flf/RecognizerV2.hh @@ -0,0 +1,74 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RECOGNIZER_V2_HH +#define RECOGNIZER_V2_HH + +#include +#include +#include +#include +#include "Network.hh" +#include "SegmentwiseSpeechProcessor.hh" +#include "Speech/ModelCombination.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(std::string const& name, Core::Configuration const& config); + +/* + * Node to run recognition on speech segments using a `SearchAlgorithmV2` internally. 
+ */ +class RecognizerNodeV2 : public Node { +public: + RecognizerNodeV2(std::string const& name, Core::Configuration const& config); + + virtual ~RecognizerNodeV2() { + delete searchAlgorithm_; + } + + // Inherited methods + virtual void init(std::vector const& arguments) override; + virtual void sync() override; + virtual void finalize() override; + + virtual ConstSegmentRef sendSegment(Port to) override; + virtual ConstLatticeRef sendLattice(Port to) override; + +private: + /* + * Perform recognition of `segment` using `searchAlgorithm_` and store the result in `resultBuffer_` + */ + void recognizeSegment(const Bliss::SpeechSegment* segment); + + /* + * Requests input segment and runs recognition on it + */ + void work(); + + /* + * Convert an output lattice from `searchAlgorithm_` to an Flf lattice + */ + ConstLatticeRef buildLattice(Core::Ref latticeAdaptor, std::string segmentName); + + std::pair resultBuffer_; + + Search::SearchAlgorithmV2* searchAlgorithm_; + Speech::ModelCombination modelCombination_; + SegmentwiseFeatureExtractorRef featureExtractor_; +}; + +} // namespace Flf + +#endif // RECOGNIZER_V2_HH diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index 075c9d8c..d417327e 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -16,6 +16,7 @@ #include #include #include +#include "Am/AcousticModel.hh" using namespace Speech; @@ -32,16 +33,43 @@ const Core::ParameterFloat ModelCombination::paramPronunciationScale( ModelCombination::ModelCombination(const Core::Configuration& c, Mode mode, - Am::AcousticModel::Mode acousticModelMode) + Am::AcousticModel::Mode acousticModelMode, + Bliss::LexiconRef lexicon) : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setLexicon(Bliss::Lexicon::create(select("lexicon"))); - if (!lexicon_) - criticalError("failed to initialize the lexicon"); + setPronunciationScale(paramPronunciationScale(c)); + build(mode, acousticModelMode, lexicon); +} - /*! 
\todo Scalable lexicon not implemented yet */ +ModelCombination::ModelCombination(const Core::Configuration& c, + Bliss::LexiconRef lexicon, + Core::Ref acousticModel, + Core::Ref languageModel) + : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { setPronunciationScale(paramPronunciationScale(c)); + setLexicon(lexicon); + setAcousticModel(acousticModel); + setLanguageModel(languageModel); +} + +ModelCombination::~ModelCombination() {} + +void ModelCombination::build(Mode mode, + Am::AcousticModel::Mode acousticModelMode, + Bliss::LexiconRef lexicon) { + if (lexicon) { + setLexicon(lexicon); + log() << "Set lexicon in ModelCombination"; + } + else { + log() << "Create lexicon in ModelCombination"; + setLexicon(Bliss::Lexicon::create(select("lexicon"))); + } + + if (!lexicon_) { + criticalError("failed to initialize the lexicon"); + } if (mode & useAcousticModel) { setAcousticModel(Am::Module::instance().createAcousticModel( @@ -58,24 +86,12 @@ ModelCombination::ModelCombination(const Core::Configuration& c, if (mode & useLabelScorer) { setLabelScorer(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("label-scorer"))); - if (!labelScorer_) + if (!labelScorer_) { criticalError("failed to initialize label scorer"); + } } } -ModelCombination::ModelCombination(const Core::Configuration& c, - Bliss::LexiconRef lexicon, - Core::Ref acousticModel, - Core::Ref languageModel) - : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setPronunciationScale(paramPronunciationScale(c)); - setLexicon(lexicon); - setAcousticModel(acousticModel); - setLanguageModel(languageModel); -} - -ModelCombination::~ModelCombination() {} - void ModelCombination::setLexicon(Bliss::LexiconRef lexicon) { lexicon_ = lexicon; } diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index 27466749..bfbcdb68 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -23,7 +23,6 @@ #include #include - namespace 
Speech { /** Combination of a lexicon, an acoustic model or label scorer, and a language model. @@ -65,11 +64,14 @@ protected: public: ModelCombination(const Core::Configuration&, Mode = complete, - Am::AcousticModel::Mode = Am::AcousticModel::complete); + Am::AcousticModel::Mode = Am::AcousticModel::complete, + Bliss::LexiconRef = Bliss::LexiconRef()); ModelCombination(const Core::Configuration&, Bliss::LexiconRef, Core::Ref, Core::Ref); virtual ~ModelCombination(); + void build(Mode = complete, Am::AcousticModel::Mode = Am::AcousticModel::complete, Bliss::LexiconRef = Bliss::LexiconRef()); + void getDependencies(Core::DependencySet&) const; Bliss::LexiconRef lexicon() const { @@ -88,8 +90,12 @@ public: } void setLanguageModel(Core::Ref); - void setLabelScorer(Core::Ref ls) { labelScorer_ = ls; } - Core::Ref labelScorer() const { return labelScorer_; } + void setLabelScorer(Core::Ref ls) { + labelScorer_ = ls; + } + Core::Ref labelScorer() const { + return labelScorer_; + } }; typedef Core::Ref ModelCombinationRef; From 1dc47e6e6d5dd30fe3db23019d8f67f48cb19ba8 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 6 Mar 2025 15:24:42 +0100 Subject: [PATCH 26/52] Make `modelCombination_` a ref + some formatting --- src/Flf/RecognizerV2.cc | 18 ++++++++----- src/Flf/RecognizerV2.hh | 6 ++--- src/Speech/ModelCombination.cc | 49 +++++++++++++++++----------------- src/Speech/ModelCombination.hh | 15 +++++++---- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index f8a26a02..bcf7be5b 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -28,7 +28,7 @@ NodeRef createRecognizerNodeV2(const std::string& name, const Core::Configuratio RecognizerNodeV2::RecognizerNodeV2(const std::string& name, const Core::Configuration& config) : Node(name, config), searchAlgorithm_(Search::Module::instance().createSearchAlgorithm(select("search-algorithm"))), - modelCombination_(config) { + 
modelCombination_() { Core::Configuration featureExtractionConfig(config, "feature-extraction"); DataSourceRef dataSource = DataSourceRef(Speech::Module::instance().createDataSource(featureExtractionConfig)); featureExtractor_ = SegmentwiseFeatureExtractorRef(new SegmentwiseFeatureExtractor(featureExtractionConfig, dataSource)); @@ -74,7 +74,7 @@ void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { Core::XmlWriter& os(clog()); os << Core::XmlOpen("traceback"); - traceback->write(os, modelCombination_.lexicon()->phonemeInventory()); + traceback->write(os, modelCombination_->lexicon()->phonemeInventory()); os << Core::XmlClose("traceback"); os << Core::XmlOpen("orth") + Core::XmlAttribute("source", "recognized"); @@ -104,10 +104,10 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::RefsetKey(0, "am"); semiring->setScale(0, 1.0); semiring->setKey(1, "lm"); - semiring->setScale(1, modelCombination_.languageModel()->scale()); + semiring->setScale(1, modelCombination_->languageModel()->scale()); auto sentenceEndLabel = Fsa::Epsilon; - const Bliss::Lemma* specialSentenceEndLemma = modelCombination_.lexicon()->specialLemma("sentence-end"); + const Bliss::Lemma* specialSentenceEndLemma = modelCombination_->lexicon()->specialLemma("sentence-end"); if (specialSentenceEndLemma and specialSentenceEndLemma->nPronunciations() > 0) { sentenceEndLabel = specialSentenceEndLemma->pronunciations().first->id(); } @@ -127,7 +127,7 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::RefsetType(Fsa::TypeAcceptor); s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); - s->setInputAlphabet(modelCombination_.lexicon()->lemmaPronunciationAlphabet()); + s->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); s->setSemiring(semiring); s->setDescription(Core::form("recog(%s)", segmentName.c_str())); s->setBoundaries(ConstBoundariesRef(b)); @@ -198,8 +198,12 @@ ConstLatticeRef 
RecognizerNodeV2::buildLattice(Core::Ref const& arguments) { - modelCombination_.build(searchAlgorithm_->requiredModelCombination(), searchAlgorithm_->requiredAcousticModel(), Lexicon::us()); - searchAlgorithm_->setModelCombination(modelCombination_); + modelCombination_ = Core::ref(new Speech::ModelCombination( + config, + searchAlgorithm_->requiredModelCombination(), + searchAlgorithm_->requiredAcousticModel(), + Lexicon::us())); + searchAlgorithm_->setModelCombination(*modelCombination_); if (not connected(0)) { criticalError("Speech segment at port 1 required"); } diff --git a/src/Flf/RecognizerV2.hh b/src/Flf/RecognizerV2.hh index ea6b04e2..3898216f 100644 --- a/src/Flf/RecognizerV2.hh +++ b/src/Flf/RecognizerV2.hh @@ -64,9 +64,9 @@ private: std::pair resultBuffer_; - Search::SearchAlgorithmV2* searchAlgorithm_; - Speech::ModelCombination modelCombination_; - SegmentwiseFeatureExtractorRef featureExtractor_; + Search::SearchAlgorithmV2* searchAlgorithm_; + Core::Ref modelCombination_; + SegmentwiseFeatureExtractorRef featureExtractor_; }; } // namespace Flf diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index d417327e..7bd385be 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -39,25 +39,7 @@ ModelCombination::ModelCombination(const Core::Configuration& c, Mc::Component(c), pronunciationScale_(0) { setPronunciationScale(paramPronunciationScale(c)); - build(mode, acousticModelMode, lexicon); -} - -ModelCombination::ModelCombination(const Core::Configuration& c, - Bliss::LexiconRef lexicon, - Core::Ref acousticModel, - Core::Ref languageModel) - : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setPronunciationScale(paramPronunciationScale(c)); - setLexicon(lexicon); - setAcousticModel(acousticModel); - setLanguageModel(languageModel); -} -ModelCombination::~ModelCombination() {} - -void ModelCombination::build(Mode mode, - Am::AcousticModel::Mode acousticModelMode, - 
Bliss::LexiconRef lexicon) { if (lexicon) { setLexicon(lexicon); log() << "Set lexicon in ModelCombination"; @@ -68,30 +50,45 @@ void ModelCombination::build(Mode mode, } if (!lexicon_) { - criticalError("failed to initialize the lexicon"); + criticalError("Failed to initialize the lexicon"); } if (mode & useAcousticModel) { setAcousticModel(Am::Module::instance().createAcousticModel( select("acoustic-model"), lexicon_, acousticModelMode)); - if (!acousticModel_) - criticalError("failed to initialize the acoustic model"); + if (!acousticModel_) { + criticalError("Failed to initialize the acoustic model"); + } } if (mode & useLanguageModel) { setLanguageModel(Lm::Module::instance().createScaledLanguageModel(select("lm"), lexicon_)); - if (!languageModel_) - criticalError("failed to initialize language model"); + if (!languageModel_) { + criticalError("Failed to initialize language model"); + } } if (mode & useLabelScorer) { setLabelScorer(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("label-scorer"))); if (!labelScorer_) { - criticalError("failed to initialize label scorer"); + criticalError("Failed to initialize label scorer"); } } } +ModelCombination::ModelCombination(const Core::Configuration& c, + Bliss::LexiconRef lexicon, + Core::Ref acousticModel, + Core::Ref languageModel) + : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { + setPronunciationScale(paramPronunciationScale(c)); + setLexicon(lexicon); + setAcousticModel(acousticModel); + setLanguageModel(languageModel); +} + +ModelCombination::~ModelCombination() {} + void ModelCombination::setLexicon(Bliss::LexiconRef lexicon) { lexicon_ = lexicon; } @@ -108,6 +105,10 @@ void ModelCombination::setLanguageModel(Core::Ref langu languageModel_->setParentScale(scale()); } +void ModelCombination::setLabelScorer(Core::Ref ls) { + labelScorer_ = ls; +} + void ModelCombination::distributeScaleUpdate(const Mc::ScaleUpdate& scaleUpdate) { if (lexicon_) { Mm::Score scale; diff --git 
a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index bfbcdb68..d94c9d1f 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -77,25 +77,30 @@ public: Bliss::LexiconRef lexicon() const { return lexicon_; } - void setLexicon(Bliss::LexiconRef); + + void setLexicon(Bliss::LexiconRef); + Mm::Score pronunciationScale() const { return pronunciationScale_ * scale(); } + Core::Ref acousticModel() const { return acousticModel_; } - void setAcousticModel(Core::Ref); + + void setAcousticModel(Core::Ref); + Core::Ref languageModel() const { return languageModel_; } + void setLanguageModel(Core::Ref); - void setLabelScorer(Core::Ref ls) { - labelScorer_ = ls; - } Core::Ref labelScorer() const { return labelScorer_; } + + void setLabelScorer(Core::Ref ls); }; typedef Core::Ref ModelCombinationRef; From 600999b06b8fb524a4cc0494e4d5bc89f6655397 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 6 Mar 2025 18:03:31 +0100 Subject: [PATCH 27/52] Better readable lattice building function --- src/Flf/RecognizerV2.cc | 123 +++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 58 deletions(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index bcf7be5b..35121a11 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -13,9 +13,10 @@ * limitations under the License. 
*/ #include "RecognizerV2.hh" +#include +#include #include #include -#include "Core/XmlStream.hh" #include "LatticeHandler.hh" #include "Module.hh" @@ -100,11 +101,13 @@ void RecognizerNodeV2::work() { } ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref latticeAdaptor, std::string segmentName) { + auto lmScale = modelCombination_->languageModel()->scale(); + auto semiring = Semiring::create(Fsa::SemiringTypeTropical, 2); semiring->setKey(0, "am"); semiring->setScale(0, 1.0); semiring->setKey(1, "lm"); - semiring->setScale(1, modelCombination_->languageModel()->scale()); + semiring->setScale(1, lmScale); auto sentenceEndLabel = Fsa::Epsilon; const Bliss::Lemma* specialSentenceEndLemma = modelCombination_->lexicon()->specialLemma("sentence-end"); @@ -123,78 +126,82 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Refpart(::Lattice::WordLattice::lmFsa); require_(Fsa::isAcyclic(amFsa) && Fsa::isAcyclic(lmFsa)); - StaticBoundariesRef b = StaticBoundariesRef(new StaticBoundaries); - StaticLatticeRef s = StaticLatticeRef(new StaticLattice); - s->setType(Fsa::TypeAcceptor); - s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); - s->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); - s->setSemiring(semiring); - s->setDescription(Core::form("recog(%s)", segmentName.c_str())); - s->setBoundaries(ConstBoundariesRef(b)); - s->setInitialStateId(0); + StaticBoundariesRef flfBoundaries = StaticBoundariesRef(new StaticBoundaries); + StaticLatticeRef flfLattice = StaticLatticeRef(new StaticLattice); + flfLattice->setType(Fsa::TypeAcceptor); + flfLattice->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); + flfLattice->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); + flfLattice->setSemiring(semiring); + flfLattice->setDescription(Core::form("recog(%s)", segmentName.c_str())); + flfLattice->setBoundaries(ConstBoundariesRef(flfBoundaries)); + 
flfLattice->setInitialStateId(0); Time timeOffset = (*boundaries)[amFsa->initialStateId()].time(); Fsa::Stack stateStack; - Core::Vector sidMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); - sidMap[amFsa->initialStateId()] = 0; + Core::Vector stateIdMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); + stateIdMap[amFsa->initialStateId()] = 0; stateStack.push_back(amFsa->initialStateId()); - Fsa::StateId nextSid = 2; - Time finalTime = 0; + Fsa::StateId nextStateId = 2; + Time finalTime = 0; while (not stateStack.isEmpty()) { - Fsa::StateId sid = stateStack.pop(); - verify(sid < sidMap.size()); - const ::Lattice::WordBoundary& boundary((*boundaries)[sid]); - Fsa::ConstStateRef amSr = amFsa->getState(sid); - Fsa::ConstStateRef lmSr = lmFsa->getState(sid); - State* sp = new State(sidMap[sid]); - s->setState(sp); - b->set(sp->id(), Boundary(boundary.time() - timeOffset, - Boundary::Transit(boundary.transit().final, boundary.transit().initial))); - if (amSr->isFinal()) { + Fsa::StateId stateId = stateStack.pop(); + verify(stateId < stateIdMap.size()); + const ::Lattice::WordBoundary& boundary((*boundaries)[stateId]); + Fsa::ConstStateRef amFsaState = amFsa->getState(stateId); + Fsa::ConstStateRef lmFsaState = lmFsa->getState(stateId); + State* flfState = new State(stateIdMap[stateId]); + flfLattice->setState(flfState); + flfBoundaries->set(flfState->id(), Boundary(boundary.time() - timeOffset, + Boundary::Transit(boundary.transit().final, boundary.transit().initial))); + if (amFsaState->isFinal()) { auto scores = semiring->create(); - scores->set(0, amSr->weight()); - scores->set(1, static_cast(lmSr->weight()) / semiring->scale(1)); - sp->newArc(1, scores, sentenceEndLabel); + scores->set(0, amFsaState->weight()); + if (lmScale) { + scores->set(1, static_cast(lmFsaState->weight()) / lmScale); + } + else { + scores->set(1, 0.0); + } + flfState->newArc(1, scores, sentenceEndLabel); finalTime = std::max(finalTime, boundary.time() - timeOffset); } - for 
(Fsa::State::const_iterator am_a = amSr->begin(), lm_a = lmSr->begin(); (am_a != amSr->end()) && (lm_a != lmSr->end()); ++am_a, ++lm_a) { - sidMap.grow(am_a->target(), Fsa::InvalidStateId); - if (sidMap[am_a->target()] == Fsa::InvalidStateId) { - sidMap[am_a->target()] = nextSid++; - stateStack.push(am_a->target()); + for (Fsa::State::const_iterator amArc = amFsaState->begin(), lmArc = lmFsaState->begin(); (amArc != amFsaState->end()) && (lmArc != lmFsaState->end()); ++amArc, ++lmArc) { + stateIdMap.grow(amArc->target(), Fsa::InvalidStateId); + if (stateIdMap[amArc->target()] == Fsa::InvalidStateId) { + stateIdMap[amArc->target()] = nextStateId++; + stateStack.push(amArc->target()); } - Fsa::ConstStateRef targetAmSr = amFsa->getState(am_a->target()); - Fsa::ConstStateRef targetLmSr = amFsa->getState(lm_a->target()); - if (targetAmSr->isFinal() && targetLmSr->isFinal()) { - if (am_a->input() == Fsa::Epsilon) { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); - scores->add(0, Score(targetAmSr->weight())); - scores->add(1, Score(targetLmSr->weight()) / semiring->scale(1)); - sp->newArc(1, scores, sentenceEndLabel); - } - else { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); - sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + Fsa::ConstStateRef targetAmState = amFsa->getState(amArc->target()); + Fsa::ConstStateRef targetLmState = amFsa->getState(lmArc->target()); + + auto scores = semiring->create(); + scores->set(0, amArc->weight()); + + if (lmScale) { + scores->set(1, static_cast(lmArc->weight()) / lmScale); + } + else { + scores->set(1, 0); + } + + if (targetAmState->isFinal() and targetLmState->isFinal() and amArc->input() == Fsa::Epsilon) { + scores->add(0, Score(targetAmState->weight())); + if (lmScale) { + scores->add(1, Score(targetLmState->weight()) / lmScale); } + 
flfState->newArc(1, scores, sentenceEndLabel); } else { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast(lm_a->weight()) / semiring->scale(1)); - sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + flfState->newArc(stateIdMap[amArc->target()], scores, amArc->input()); } } } - State* sp = new State(1); - sp->setFinal(semiring->clone(semiring->one())); - s->setState(sp); - b->set(sp->id(), Boundary(finalTime)); - return s; + State* finalState = new State(1); + finalState->setFinal(semiring->clone(semiring->one())); + flfLattice->setState(finalState); + flfBoundaries->set(finalState->id(), Boundary(finalTime)); + return flfLattice; } void RecognizerNodeV2::init(std::vector const& arguments) { From e309602a9bb037ff2b57c3aba0dca186384e2007 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 26 Mar 2025 19:06:38 +0100 Subject: [PATCH 28/52] Fix error string --- src/Flf/RecognizerV2.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index 35121a11..3740b488 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -212,7 +212,7 @@ void RecognizerNodeV2::init(std::vector const& arguments) { Lexicon::us())); searchAlgorithm_->setModelCombination(*modelCombination_); if (not connected(0)) { - criticalError("Speech segment at port 1 required"); + criticalError("Speech segment at port 0 required"); } } From 52dc19269be652488956bbfe28775b6dbb07bb1e Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 27 Mar 2025 16:01:51 +0100 Subject: [PATCH 29/52] Add DataView class to replace feature input/output of LabelScorer and Encoder --- src/Flf/RecognizerV2.cc | 2 +- src/Nn/LabelScorer/BufferedLabelScorer.cc | 11 +-- src/Nn/LabelScorer/BufferedLabelScorer.hh | 8 +- src/Nn/LabelScorer/CombineLabelScorer.cc | 8 +- src/Nn/LabelScorer/CombineLabelScorer.hh | 4 +- src/Nn/LabelScorer/DataView.cc | 48 ++++++++++++ src/Nn/LabelScorer/DataView.hh | 74 
+++++++++++++++++++ src/Nn/LabelScorer/Encoder.cc | 28 ++----- src/Nn/LabelScorer/Encoder.hh | 18 ++--- src/Nn/LabelScorer/LabelScorer.cc | 19 +---- src/Nn/LabelScorer/LabelScorer.hh | 6 +- src/Nn/LabelScorer/Makefile | 1 + src/Nn/LabelScorer/NoOpLabelScorer.cc | 7 +- src/Onnx/OnnxEncoder.cc | 16 ++-- .../LexiconfreeTimesyncBeamSearch.cc | 14 +--- .../LexiconfreeTimesyncBeamSearch.hh | 5 +- src/Search/SearchV2.hh | 5 +- src/Test/Makefile | 3 + src/Tools/Archiver/Makefile | 3 + src/Tools/CorpusStatistics/Makefile | 3 + src/Tools/FeatureExtraction/Makefile | 3 + src/Tools/FeatureStatistics/Makefile | 3 + src/Tools/LatticeProcessor/Makefile | 3 + src/Tools/NnTrainer/Makefile | 3 + 24 files changed, 195 insertions(+), 100 deletions(-) create mode 100644 src/Nn/LabelScorer/DataView.cc create mode 100644 src/Nn/LabelScorer/DataView.hh diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index 3740b488..5fe3e2c8 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -58,7 +58,7 @@ void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { // Loop over features and perform recognition do { - searchAlgorithm_->putFeature(*feature->mainStream()); + searchAlgorithm_->putFeature(feature->mainStream()); endTime = feature->timestamp().endTime(); } while (dataSource->getData(feature)); diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.cc b/src/Nn/LabelScorer/BufferedLabelScorer.cc index 8f295e3e..c561f6bf 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.cc +++ b/src/Nn/LabelScorer/BufferedLabelScorer.cc @@ -21,13 +21,11 @@ BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config) : Core::Component(config), Precursor(config), inputBuffer_(), - featureSize_(Core::Type::max), expectMoreFeatures_(true) { } void BufferedLabelScorer::reset() { inputBuffer_.clear(); - featureSize_ = Core::Type::max; expectMoreFeatures_ = true; } @@ -35,14 +33,7 @@ void BufferedLabelScorer::signalNoMoreFeatures() { expectMoreFeatures_ = 
false; } -void BufferedLabelScorer::addInput(std::shared_ptr const& input, size_t featureSize) { - if (featureSize_ == Core::Type::max) { - featureSize_ = featureSize; - } - else if (featureSize_ != featureSize) { - error() << "Label scorer received incompatible feature size " << featureSize << "; was set to " << featureSize_ << " before."; - } - +void BufferedLabelScorer::addInput(DataView const& input) { inputBuffer_.push_back(input); } diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.hh b/src/Nn/LabelScorer/BufferedLabelScorer.hh index d78d1b7f..566d6ead 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.hh +++ b/src/Nn/LabelScorer/BufferedLabelScorer.hh @@ -16,6 +16,7 @@ #ifndef BUFFERED_LABEL_SCORER_HH #define BUFFERED_LABEL_SCORER_HH +#include "DataView.hh" #include "LabelScorer.hh" namespace Nn { @@ -39,12 +40,11 @@ public: virtual void signalNoMoreFeatures() override; // Add a single input feature to the buffer - virtual void addInput(std::shared_ptr const& input, size_t featureSize) override; + virtual void addInput(DataView const& input) override; protected: - std::vector> inputBuffer_; // Buffer that contains all the feature data for the current segment - size_t featureSize_; // Feature dimension size of features in the buffer (same for all features) - bool expectMoreFeatures_; // Flag to record segment end signal + std::vector inputBuffer_; // Buffer that contains all the feature data for the current segment + bool expectMoreFeatures_; // Flag to record segment end signal }; } // namespace Nn diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index 833ace5a..cf69cba4 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -71,15 +71,15 @@ ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& requ return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); } -void CombineLabelScorer::addInput(std::shared_ptr const& 
input, size_t featureSize) { +void CombineLabelScorer::addInput(DataView const& input) { for (auto& scaledScorer : scaledScorers_) { - scaledScorer.scorer->addInput(input, featureSize); + scaledScorer.scorer->addInput(input); } } -void CombineLabelScorer::addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize) { +void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { for (auto& scaledScorer : scaledScorers_) { - scaledScorer.scorer->addInputs(input, timeSize, featureSize); + scaledScorer.scorer->addInputs(input, nTimesteps); } } diff --git a/src/Nn/LabelScorer/CombineLabelScorer.hh b/src/Nn/LabelScorer/CombineLabelScorer.hh index 0a90f9cb..01135686 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.hh +++ b/src/Nn/LabelScorer/CombineLabelScorer.hh @@ -50,10 +50,10 @@ public: ScoringContextRef extendedScoringContext(Request const& request); // Add input to all sub-scorers - void addInput(std::shared_ptr const& input, size_t featureSize); + void addInput(DataView const& input); // Add inputs to all sub-scorers - virtual void addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize); + virtual void addInputs(DataView const& input, size_t nTimesteps); // Compute weighted score of request with all sub-scorers std::optional computeScoreWithTime(Request const& request); diff --git a/src/Nn/LabelScorer/DataView.cc b/src/Nn/LabelScorer/DataView.cc new file mode 100644 index 00000000..86e19c15 --- /dev/null +++ b/src/Nn/LabelScorer/DataView.cc @@ -0,0 +1,48 @@ +#include "DataView.hh" + +namespace Nn { + +DataView::DataView(DataView const& dataView) + : dataPtr_(dataView) { +} + +DataView::DataView(std::shared_ptr const& ptr, size_t size, size_t offset) + : size_(size) { + // Use aliasing constructor to create sub-shared_ptr that shares ownership with the original one but points to offset memory location + dataPtr_ = std::shared_ptr(ptr, ptr.get() + offset); +} + +DataView::DataView(Core::Ref const& 
featureVectorRef) : size_(featureVectorRef->size()) { + // Copy Ref in custom deleter to keep it alive + dataPtr_ = std::shared_ptr( + featureVectorRef->data(), + [featureVectorRef](f32 const[]) mutable {}); +} + +#ifdef MODULE_ONNX +DataView::DataView(Onnx::Value&& value) { + // Move Onnx value into a shared_ptr to enable ref counting without requiring a copy + auto valuePtr = std::make_shared(std::move(value)); + + // Create f32 shared_ptr based on Onnx value shared_ptr + dataPtr_ = std::shared_ptr( + valuePtr->data(), + [valuePtr](f32 const[]) mutable {}); + + size_ = 1ul; + for (int d = 0ul; d < valuePtr->numDims(); ++d) { + size_ *= valuePtr->dimSize(d); + } +} +#endif + +#ifdef MODULE_PYTHON +DataView::DataView(pybind11::array_t const& array, size_t size, size_t offset) : size_(size) { + // Copy array (increasing its ref counter) in custom deleter to keep it alive + dataPtr_ = std::shared_ptr( + array.data() + offset, + [array](f32 const[]) mutable {}); +} +#endif + +} // namespace Nn diff --git a/src/Nn/LabelScorer/DataView.hh b/src/Nn/LabelScorer/DataView.hh new file mode 100644 index 00000000..c9933eb9 --- /dev/null +++ b/src/Nn/LabelScorer/DataView.hh @@ -0,0 +1,74 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DATA_VIEW +#define DATA_VIEW + +#include + +#ifdef MODULE_ONNX +#include +#endif + +#ifdef MODULE_PYTHON +#pragma push_macro("ensure") // Macro duplication in numpy.h +#undef ensure +#include +#include +#pragma pop_macro("ensure") +#endif + +namespace Nn { + +/* + * Wraps the data of various data structures in a `std::shared_ptr` + * without copying while making sure that the data is not invalidated. + * This is achieved via custom deleters in the shared_ptr's which tie + * the lifetime of the original datastructure to the shared_ptr. + */ +class DataView { +public: + DataView(DataView const& dataView); + DataView(Core::Ref const& featureVectorRef); + DataView(std::shared_ptr const& ptr, size_t size, size_t offset = 0ul); + +#ifdef MODULE_ONNX + DataView(Onnx::Value&& value); +#endif + +#ifdef MODULE_ONNX + DataView(pybind11::array_t const& array, size_t size, size_t offset = 0ul); +#endif + + operator std::shared_ptr() const { + return dataPtr_; + } + + f32 const* data() const { + return dataPtr_.get(); + } + + size_t size() const { + return size_; + } + +private: + std::shared_ptr dataPtr_; + size_t size_; +}; + +} // namespace Nn + +#endif // DATA_VIEW diff --git a/src/Nn/LabelScorer/Encoder.cc b/src/Nn/LabelScorer/Encoder.cc index d8a1e553..e5170097 100644 --- a/src/Nn/LabelScorer/Encoder.cc +++ b/src/Nn/LabelScorer/Encoder.cc @@ -21,17 +21,12 @@ Encoder::Encoder(Core::Configuration const& config) : Core::Component(config), inputBuffer_(), outputBuffer_(), - featureSize_(Core::Type::max), - outputSize_(Core::Type::max), expectMoreFeatures_(true) {} void Encoder::reset() { expectMoreFeatures_ = true; inputBuffer_.clear(); - featureSize_ = Core::Type::max; - outputSize_ = Core::Type::max; - outputBuffer_.clear(); } @@ -39,21 +34,14 @@ void Encoder::signalNoMoreFeatures() { expectMoreFeatures_ = false; } -void Encoder::addInput(std::shared_ptr const& input, size_t featureSize) { - if (featureSize_ == Core::Type::max) { - featureSize_ = featureSize; 
- } - else if (featureSize_ != featureSize) { - error() << "Encoder received incompatible feature size " << featureSize << "; was set to " << featureSize_ << " before."; - } - +void Encoder::addInput(DataView const& input) { inputBuffer_.push_back(input); } -void Encoder::addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize) { - for (size_t t = 0ul; t < timeSize; ++t) { - // Use aliasing constructor to create sub-`shared_ptr`s that share ownership with the original one but point to different memory locations - addInput(std::shared_ptr(input, input.get() + t * featureSize), featureSize); +void Encoder::addInputs(DataView const& input, size_t nTimesteps) { + auto featureSize = input.size() / nTimesteps; + for (size_t t = 0ul; t < nTimesteps; ++t) { + addInput({input, featureSize, t * featureSize}); } } @@ -61,7 +49,7 @@ bool Encoder::canEncode() const { return not inputBuffer_.empty() and not expectMoreFeatures_; } -std::optional> Encoder::getNextOutput() { +std::optional Encoder::getNextOutput() { // Check if there are still outputs in the buffer to pass if (not outputBuffer_.empty()) { auto result = outputBuffer_.front(); @@ -87,10 +75,6 @@ std::optional> Encoder::getNextOutput() { return getNextOutput(); } -size_t Encoder::getOutputSize() const { - return outputSize_; -} - void Encoder::postEncodeCleanup() { inputBuffer_.clear(); } diff --git a/src/Nn/LabelScorer/Encoder.hh b/src/Nn/LabelScorer/Encoder.hh index 7c9827ea..648cda71 100644 --- a/src/Nn/LabelScorer/Encoder.hh +++ b/src/Nn/LabelScorer/Encoder.hh @@ -20,6 +20,7 @@ #include #include +#include "DataView.hh" namespace Nn { @@ -40,26 +41,21 @@ public: void signalNoMoreFeatures(); // Add a single input feature - virtual void addInput(std::shared_ptr const& input, size_t featureSize); + virtual void addInput(DataView const& input); // Add input features for multiple time steps at once - virtual void addInputs(std::shared_ptr const& inputs, size_t timeSize, size_t featureSize); + virtual 
void addInputs(DataView const& inputs, size_t nTimesteps); // Retrieve the next encoder output frame // Performs encoder forwarding internally if necessary // Can return None if not enough input features are available yet - std::optional> getNextOutput(); - - // Get dimension of outputs that are fetched via `getNextOutput`. - size_t getOutputSize() const; + std::optional getNextOutput(); protected: - std::deque> inputBuffer_; - std::deque> outputBuffer_; + std::deque inputBuffer_; + std::deque outputBuffer_; - size_t featureSize_; - size_t outputSize_; - bool expectMoreFeatures_; + bool expectMoreFeatures_; // Encode features inside the input buffer and put the results into the output buffer virtual void encode() = 0; diff --git a/src/Nn/LabelScorer/LabelScorer.cc b/src/Nn/LabelScorer/LabelScorer.cc index 0a0c95ca..010aa56a 100644 --- a/src/Nn/LabelScorer/LabelScorer.cc +++ b/src/Nn/LabelScorer/LabelScorer.cc @@ -26,21 +26,10 @@ namespace Nn { LabelScorer::LabelScorer(const Core::Configuration& config) : Core::Component(config) {} -void LabelScorer::addInput(std::vector const& input) { - // The custom deleter ties the lifetime of vector `input` to the lifetime - // of `dataPtr` by capturing the `inputWrapper` by value. - // This makes sure that the underlying data isn't invalidated prematurely. 
- auto inputWrapper = std::make_shared>(input); - auto dataPtr = std::shared_ptr( - inputWrapper->data(), - [inputWrapper](const f32*) mutable {}); - addInput(dataPtr, input.size()); -} - -void LabelScorer::addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize) { - for (size_t t = 0ul; t < timeSize; ++t) { - // Use aliasing constructor to create sub-`shared_ptr`s that share ownership with the original one but point to different memory locations - addInput(std::shared_ptr(input, input.get() + t * featureSize), featureSize); +void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { + auto featureSize = input.size() / nTimesteps; + for (size_t t = 0ul; t < nTimesteps; ++t) { + addInput({input, featureSize, t * featureSize}); } } diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 788edbb7..40b26007 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -32,6 +32,7 @@ #include #include +#include "DataView.hh" #include "ScoringContext.hh" namespace Nn { @@ -123,11 +124,10 @@ public: virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; // Add a single input feature - virtual void addInput(std::shared_ptr const& input, size_t featureSize) = 0; - virtual void addInput(std::vector const& input); + virtual void addInput(DataView const& input) = 0; // Add input features for multiple time steps at once - virtual void addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize); + virtual void addInputs(DataView const& input, size_t nTimesteps); // Perform scoring computation for a single request // Return score and timeframe index of the corresponding output diff --git a/src/Nn/LabelScorer/Makefile b/src/Nn/LabelScorer/Makefile index 50bc621b..3800d32a 100644 --- a/src/Nn/LabelScorer/Makefile +++ b/src/Nn/LabelScorer/Makefile @@ -12,6 +12,7 @@ TARGETS = libSprintLabelScorer.$(a) LIBSPRINTLABELSCORER_O = \ 
$(OBJDIR)/BufferedLabelScorer.o \ $(OBJDIR)/CombineLabelScorer.o \ + $(OBJDIR)/DataView.o \ $(OBJDIR)/Encoder.o \ $(OBJDIR)/EncoderFactory.o \ $(OBJDIR)/LabelScorer.o \ diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index c72c6fc3..828a8f3f 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -35,11 +35,12 @@ std::optional StepwiseNoOpLabelScorer::computeScoreW if (inputBuffer_.size() <= stepHistory->currentStep) { return {}; } - if (request.nextToken >= featureSize_) { - error() << "Tried to get score for token index " << request.nextToken << " but only have " << featureSize_ << " scores available."; + auto& currentInput = inputBuffer_.at(stepHistory->currentStep); + if (request.nextToken >= currentInput.size()) { + error() << "Tried to get score for token index " << request.nextToken << " but only have " << currentInput.size() << " scores available."; } - return ScoreWithTime{inputBuffer_.at(stepHistory->currentStep)[request.nextToken], stepHistory->currentStep}; + return ScoreWithTime{currentInput.data()[request.nextToken], stepHistory->currentStep}; } } // namespace Nn diff --git a/src/Onnx/OnnxEncoder.cc b/src/Onnx/OnnxEncoder.cc index d37dc51f..47e41f36 100644 --- a/src/Onnx/OnnxEncoder.cc +++ b/src/Onnx/OnnxEncoder.cc @@ -60,14 +60,14 @@ void OnnxEncoder::encode() { std::vector> sessionInputs; size_t T_in = inputBuffer_.size(); - size_t F = featureSize_; + size_t F = inputBuffer_.front().size(); std::vector featuresShape = {1l, static_cast(T_in), static_cast(F)}; Value value = Value::createEmpty(featuresShape); for (size_t t = 0ul; t < T_in; ++t) { - std::copy(inputBuffer_[t].get(), inputBuffer_[t].get() + F, value.data(0, t)); + std::copy(inputBuffer_[t].data(), inputBuffer_[t].data() + F, value.data(0, t)); } sessionInputs.emplace_back(std::make_pair(featuresName_, std::move(value))); @@ -85,16 +85,14 @@ void OnnxEncoder::encode() { /* * Put outputs into buffer */ - 
auto onnxOutputValueWrapper = std::make_shared(std::move(sessionOutputs.front())); + size_t T_out = sessionOutputs.front().dimSize(1); + size_t outputSize = sessionOutputs.front().dimSize(2); - size_t T_out = onnxOutputValueWrapper->dimSize(1); - outputSize_ = onnxOutputValueWrapper->dimSize(2); + // Make "global" DataView from output value so that feature slice DataViews can be created from it that ref-count the original value + Nn::DataView onnxOutputView(std::move(sessionOutputs.front())); for (size_t t = 0ul; t < T_out; ++t) { - auto frameOutputPtr = std::shared_ptr( - onnxOutputValueWrapper->data() + t * outputSize_, - [onnxOutputValueWrapper](const f32*) mutable {}); - outputBuffer_.push_back(frameOutputPtr); + outputBuffer_.push_back({onnxOutputView, outputSize, t * outputSize}); } } diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 901e31e5..4eddeb94 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -135,21 +135,15 @@ void LexiconfreeTimesyncBeamSearch::finishSegment() { logStatistics(); } -void LexiconfreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { +void LexiconfreeTimesyncBeamSearch::putFeature(Nn::DataView const& feature) { featureProcessingTime_.start(); - labelScorer_->addInput(data, featureSize); + labelScorer_->addInput(feature); featureProcessingTime_.stop(); } -void LexiconfreeTimesyncBeamSearch::putFeature(std::vector const& data) { +void LexiconfreeTimesyncBeamSearch::putFeatures(Nn::DataView const& features, size_t nTimesteps) { featureProcessingTime_.start(); - labelScorer_->addInput(data); - featureProcessingTime_.stop(); -} - -void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { - 
featureProcessingTime_.start(); - labelScorer_->addInputs(data, timeSize, featureSize); + labelScorer_->addInputs(features, nTimesteps); featureProcessingTime_.stop(); } diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 71f19317..2718035b 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -91,9 +91,8 @@ public: void reset() override; void enterSegment(Bliss::SpeechSegment const* = nullptr) override; void finishSegment() override; - void putFeature(std::shared_ptr const& data, size_t featureSize) override; - void putFeature(std::vector const& data) override; - void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) override; + void putFeature(Nn::DataView const& feature) override; + void putFeatures(Nn::DataView const& features, size_t nTimesteps) override; Core::Ref getCurrentBestTraceback() const override; Core::Ref getCurrentBestWordLattice() const override; bool decodeStep() override; diff --git a/src/Search/SearchV2.hh b/src/Search/SearchV2.hh index 4548b29a..aa695036 100644 --- a/src/Search/SearchV2.hh +++ b/src/Search/SearchV2.hh @@ -78,11 +78,10 @@ public: virtual void finishSegment() = 0; // Pass a single feature vector. - virtual void putFeature(std::shared_ptr const& data, size_t featureSize) = 0; - virtual void putFeature(std::vector const& data) = 0; + virtual void putFeature(Nn::DataView const& feature) = 0; // Pass feature vectors for multiple time steps. - virtual void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) = 0; + virtual void putFeatures(Nn::DataView const& features, size_t nTimesteps) = 0; // Return the current best traceback. May contain unstable results. 
virtual Core::Ref getCurrentBestTraceback() const = 0; diff --git a/src/Test/Makefile b/src/Test/Makefile index 52c80e5a..a42b9aec 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -104,6 +104,9 @@ endif ifdef MODULE_OPENFST UNIT_TEST_O += ../OpenFst/libSprintOpenFst.$(a) endif +ifdef MODULE_ONNX +UNIT_TEST_O += ../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW UNIT_TEST_O += ../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index 6e1b762a..5da403af 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -43,6 +43,9 @@ endif ifdef MODULE_OPENFST ARCHIVER_O += ../../OpenFst/libSprintOpenFst.$(a) endif +ifdef MODULE_ONNX +ARCHIVER_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_LM_TFRNN ARCHIVER_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/CorpusStatistics/Makefile b/src/Tools/CorpusStatistics/Makefile index 0a30e754..b4579a28 100644 --- a/src/Tools/CorpusStatistics/Makefile +++ b/src/Tools/CorpusStatistics/Makefile @@ -40,6 +40,9 @@ endif ifdef MODULE_NN CORPUS_STATISTICS_O += ../../Nn/libSprintNn.$(a) endif +ifdef MODULE_ONNX +CORPUS_STATISTICS_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW CORPUS_STATISTICS_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/FeatureExtraction/Makefile b/src/Tools/FeatureExtraction/Makefile index 17c4451e..a5eee059 100644 --- a/src/Tools/FeatureExtraction/Makefile +++ b/src/Tools/FeatureExtraction/Makefile @@ -46,6 +46,9 @@ endif ifdef MODULE_NN FEATURE_EXTRACTION_O += ../../Nn/libSprintNn.$(a) endif +ifdef MODULE_ONNX +FEATURE_EXTRACTION_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW FEATURE_EXTRACTION_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/FeatureStatistics/Makefile b/src/Tools/FeatureStatistics/Makefile index 
ccc7041b..7ab3285b 100644 --- a/src/Tools/FeatureStatistics/Makefile +++ b/src/Tools/FeatureStatistics/Makefile @@ -43,6 +43,9 @@ endif ifdef MODULE_NN FEATURE_STATISTICS_O += ../../Nn/libSprintNn.$(a) endif +ifdef MODULE_ONNX +FEATURE_STATISTICS_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW FEATURE_STATISTICS_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/LatticeProcessor/Makefile b/src/Tools/LatticeProcessor/Makefile index 40d67f10..a360a95b 100644 --- a/src/Tools/LatticeProcessor/Makefile +++ b/src/Tools/LatticeProcessor/Makefile @@ -40,6 +40,9 @@ endif ifdef MODULE_NN LATTICE_PROCESSOR_O += ../../Nn/libSprintNn.$(a) endif +ifdef MODULE_ONNX +LATTICE_PROCESSOR_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW LATTICE_PROCESSOR_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index 707285ce..b15a05eb 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -47,6 +47,9 @@ endif ifdef MODULE_PYTHON NN_TRAINER_O += ../../Python/libSprintPython.$(a) endif +ifdef MODULE_ONNX +NN_TRAINER_O += ../../Onnx/libSprintOnnx.$(a) +endif ifdef MODULE_TENSORFLOW NN_TRAINER_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) From 8dd7251220b166f819f69e26a4fe106b9362aebb Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 27 Mar 2025 16:25:00 +0100 Subject: [PATCH 30/52] Rewrite docstring, remove static_ptr cast, add operator[] function --- src/Nn/LabelScorer/DataView.cc | 6 +++++- src/Nn/LabelScorer/DataView.hh | 21 +++++++++++++-------- src/Nn/LabelScorer/NoOpLabelScorer.cc | 6 +----- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/Nn/LabelScorer/DataView.cc b/src/Nn/LabelScorer/DataView.cc index 86e19c15..d07c4784 100644 --- a/src/Nn/LabelScorer/DataView.cc +++ b/src/Nn/LabelScorer/DataView.cc @@ -3,7 +3,11 @@ namespace Nn { 
DataView::DataView(DataView const& dataView) - : dataPtr_(dataView) { + : dataPtr_(dataView.dataPtr_), size_(dataView.size_) { +} + +DataView::DataView(DataView const& dataView, size_t size, size_t offset) + : dataPtr_(dataView.dataPtr_, dataView.data() + offset), size_(size) { } DataView::DataView(std::shared_ptr const& ptr, size_t size, size_t offset) diff --git a/src/Nn/LabelScorer/DataView.hh b/src/Nn/LabelScorer/DataView.hh index c9933eb9..fb7f94a1 100644 --- a/src/Nn/LabelScorer/DataView.hh +++ b/src/Nn/LabelScorer/DataView.hh @@ -33,14 +33,18 @@ namespace Nn { /* - * Wraps the data of various data structures in a `std::shared_ptr` - * without copying while making sure that the data is not invalidated. - * This is achieved via custom deleters in the shared_ptr's which tie - * the lifetime of the original datastructure to the shared_ptr. + * This class encapsulates a std::shared_ptr and a size. The internal shared_ptr is tied + * to the lifetime of the original data container in order to make sure + * it stays valid as long as the view is alive. + * + * It can be initialized using various data containers such as a + * Core::Ref, another shared_ptr, an Onnx::Value + * or a pybind11::array_t. 
*/ class DataView { public: DataView(DataView const& dataView); + DataView(DataView const& dataView, size_t size, size_t offset = 0ul); DataView(Core::Ref const& featureVectorRef); DataView(std::shared_ptr const& ptr, size_t size, size_t offset = 0ul); @@ -52,14 +56,15 @@ public: DataView(pybind11::array_t const& array, size_t size, size_t offset = 0ul); #endif - operator std::shared_ptr() const { - return dataPtr_; - } - f32 const* data() const { return dataPtr_.get(); } + f32 operator[](size_t idx) const { + verify(idx < size_); + return dataPtr_[idx]; + } + size_t size() const { return size_; } diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index 828a8f3f..c20b098a 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -35,12 +35,8 @@ std::optional StepwiseNoOpLabelScorer::computeScoreW if (inputBuffer_.size() <= stepHistory->currentStep) { return {}; } - auto& currentInput = inputBuffer_.at(stepHistory->currentStep); - if (request.nextToken >= currentInput.size()) { - error() << "Tried to get score for token index " << request.nextToken << " but only have " << currentInput.size() << " scores available."; - } - return ScoreWithTime{currentInput.data()[request.nextToken], stepHistory->currentStep}; + return ScoreWithTime{inputBuffer_.at(stepHistory->currentStep)[request.nextToken], stepHistory->currentStep}; } } // namespace Nn From 1cad5984655d8ebd249580df81398cc8f02ebd6a Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 1 Apr 2025 21:22:40 +0200 Subject: [PATCH 31/52] Fix indentation --- src/Nn/LabelScorer/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Nn/LabelScorer/Makefile b/src/Nn/LabelScorer/Makefile index 3800d32a..0afc0b47 100644 --- a/src/Nn/LabelScorer/Makefile +++ b/src/Nn/LabelScorer/Makefile @@ -12,7 +12,7 @@ TARGETS = libSprintLabelScorer.$(a) LIBSPRINTLABELSCORER_O = \ $(OBJDIR)/BufferedLabelScorer.o \ 
$(OBJDIR)/CombineLabelScorer.o \ - $(OBJDIR)/DataView.o \ + $(OBJDIR)/DataView.o \ $(OBJDIR)/Encoder.o \ $(OBJDIR)/EncoderFactory.o \ $(OBJDIR)/LabelScorer.o \ From b6ae9d4036661b5eabd229f6ec7c315afcf313fe Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 4 Apr 2025 17:43:15 +0200 Subject: [PATCH 32/52] Remove unnecessary includes --- src/Search/Traceback.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index b4fd79ca..cc16ecc2 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -14,11 +14,6 @@ */ #include "Traceback.hh" -#include -#include -#include - -#include #include From 6ecb64416ba9fb3240bf6ae7cdf58503df329181 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 4 Apr 2025 17:44:35 +0200 Subject: [PATCH 33/52] Formatting --- src/Nn/LabelScorer/DataView.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Nn/LabelScorer/DataView.cc b/src/Nn/LabelScorer/DataView.cc index d07c4784..eda928bd 100644 --- a/src/Nn/LabelScorer/DataView.cc +++ b/src/Nn/LabelScorer/DataView.cc @@ -16,7 +16,8 @@ DataView::DataView(std::shared_ptr const& ptr, size_t size, size_t dataPtr_ = std::shared_ptr(ptr, ptr.get() + offset); } -DataView::DataView(Core::Ref const& featureVectorRef) : size_(featureVectorRef->size()) { +DataView::DataView(Core::Ref const& featureVectorRef) + : size_(featureVectorRef->size()) { // Copy Ref in custom deleter to keep it alive dataPtr_ = std::shared_ptr( featureVectorRef->data(), @@ -41,7 +42,8 @@ DataView::DataView(Onnx::Value&& value) { #endif #ifdef MODULE_PYTHON -DataView::DataView(pybind11::array_t const& array, size_t size, size_t offset) : size_(size) { +DataView::DataView(pybind11::array_t const& array, size_t size, size_t offset) + : size_(size) { // Copy array (increasing its ref counter) in custom deleter to keep it alive dataPtr_ = std::shared_ptr( array.data() + offset, From b33902d4214018fab80f5ed53b64dedbbaccf7d8 Mon Sep 17 00:00:00 
2001 From: Simon Berger Date: Fri, 4 Apr 2025 17:47:44 +0200 Subject: [PATCH 34/52] Fix #define name --- src/Nn/LabelScorer/DataView.hh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Nn/LabelScorer/DataView.hh b/src/Nn/LabelScorer/DataView.hh index fb7f94a1..42953df8 100644 --- a/src/Nn/LabelScorer/DataView.hh +++ b/src/Nn/LabelScorer/DataView.hh @@ -13,8 +13,8 @@ * limitations under the License. */ -#ifndef DATA_VIEW -#define DATA_VIEW +#ifndef DATA_VIEW_HH +#define DATA_VIEW_HH #include @@ -76,4 +76,4 @@ private: } // namespace Nn -#endif // DATA_VIEW +#endif // DATA_VIEW_HH From 277d09d92485e8608b68adf59aa013e2be119552 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 9 Apr 2025 15:35:54 +0200 Subject: [PATCH 35/52] Update EncoderDecoderLabelScorer --- src/Core/Hash.hh | 1 + .../LabelScorer/EncoderDecoderLabelScorer.cc | 24 +++++-------------- .../LabelScorer/EncoderDecoderLabelScorer.hh | 5 ++-- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/Core/Hash.hh b/src/Core/Hash.hh index 8db606c3..83167646 100644 --- a/src/Core/Hash.hh +++ b/src/Core/Hash.hh @@ -15,6 +15,7 @@ #ifndef _CORE_HASH_HH #define _CORE_HASH_HH +#include #include #include #include diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc index f8c0c038..9c8b4363 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc @@ -37,25 +37,13 @@ ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request cons return decoder_->extendedScoringContext(request); } -void EncoderDecoderLabelScorer::addInput(std::shared_ptr const& input, size_t featureSize) { - encoder_->addInput(input, featureSize); +void EncoderDecoderLabelScorer::addInput(DataView const& input) { + encoder_->addInput(input); passEncoderOutputsToDecoder(); } -void EncoderDecoderLabelScorer::addInput(std::vector const& input) { - // The custom deleter ties the 
lifetime of the vector to the lifetime - // of `dataPtr` by capturing the `inputWrapper` by value. - // This makes sure that the underlying data isn't invalidated prematurely. - auto inputWrapper = std::make_shared>(input); - auto dataPtr = std::shared_ptr( - inputWrapper->data(), - [inputWrapper](const f32*) mutable {}); - encoder_->addInput(dataPtr, input.size()); - passEncoderOutputsToDecoder(); -} - -void EncoderDecoderLabelScorer::addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize) { - encoder_->addInputs(input, timeSize, featureSize); +void EncoderDecoderLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { + encoder_->addInputs(input, nTimesteps); passEncoderOutputsToDecoder(); } @@ -76,9 +64,9 @@ std::optional EncoderDecoderLabelScorer::computeSc } void EncoderDecoderLabelScorer::passEncoderOutputsToDecoder() { - std::optional> encoderOutput; + std::optional encoderOutput; while ((encoderOutput = encoder_->getNextOutput())) { - decoder_->addInput(*encoderOutput, encoder_->getOutputSize()); + decoder_->addInput(*encoderOutput); } } diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh index 44c970a6..498c1224 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh @@ -51,11 +51,10 @@ public: // Add an input feature to the encoder component and if possible forward the encoder and add // the encoder states as inputs to the decoder component - void addInput(std::shared_ptr const& input, size_t featureSize) override; - void addInput(std::vector const& input) override; + void addInput(DataView const& input) override; // Same as `addInput` but adds features for multiple timesteps at once - void addInputs(std::shared_ptr const& input, size_t timeSize, size_t featureSize) override; + void addInputs(DataView const& input, size_t nTimesteps) override; // Run request through decoder component std::optional 
computeScoreWithTime(LabelScorer::Request const& request) override; From 815338de5f615b6b59d73eab40111df1aadc8428 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 9 Apr 2025 16:29:13 +0200 Subject: [PATCH 36/52] Add cache-cleanup functionality to LabelScorer --- src/Core/CollapsedVector.hh | 26 +++++++++++++++++++ src/Nn/LabelScorer/BufferedLabelScorer.cc | 16 ++++++++++++ src/Nn/LabelScorer/BufferedLabelScorer.hh | 12 +++++++-- src/Nn/LabelScorer/CombineLabelScorer.cc | 6 +++++ src/Nn/LabelScorer/CombineLabelScorer.hh | 19 ++++++++------ .../LabelScorer/EncoderDecoderLabelScorer.cc | 4 +++ .../LabelScorer/EncoderDecoderLabelScorer.hh | 4 +++ src/Nn/LabelScorer/LabelScorer.hh | 8 +++--- src/Nn/LabelScorer/NoOpLabelScorer.cc | 11 ++++++++ src/Nn/LabelScorer/NoOpLabelScorer.hh | 6 ++++- .../LexiconfreeTimesyncBeamSearch.cc | 22 +++++++++++----- 11 files changed, 114 insertions(+), 20 deletions(-) diff --git a/src/Core/CollapsedVector.hh b/src/Core/CollapsedVector.hh index d4d9077d..1a431491 100644 --- a/src/Core/CollapsedVector.hh +++ b/src/Core/CollapsedVector.hh @@ -50,6 +50,12 @@ public: inline void reserve(size_t size); inline const T& front() const; + inline typename std::vector::iterator begin(); + inline typename std::vector::iterator end(); + + inline typename std::vector::const_iterator begin() const; + inline typename std::vector::const_iterator end() const; + private: std::vector data_; size_t logicalSize_; @@ -124,6 +130,26 @@ inline const T& CollapsedVector::front() const { return data_.front(); } +template +inline typename std::vector::iterator CollapsedVector::begin() { + return data_.begin(); +} + +template +inline typename std::vector::iterator CollapsedVector::end() { + return data_.end(); +} + +template +inline typename std::vector::const_iterator CollapsedVector::begin() const { + return data_.begin(); +} + +template +inline typename std::vector::const_iterator CollapsedVector::end() const { + return data_.end(); +} + } // namespace Core 
#endif // COLLAPSED_VECTOR_HH diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.cc b/src/Nn/LabelScorer/BufferedLabelScorer.cc index c561f6bf..cbc4348d 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.cc +++ b/src/Nn/LabelScorer/BufferedLabelScorer.cc @@ -21,11 +21,13 @@ BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config) : Core::Component(config), Precursor(config), inputBuffer_(), + numDeletedInputs_(0ul), expectMoreFeatures_(true) { } void BufferedLabelScorer::reset() { inputBuffer_.clear(); + numDeletedInputs_ = 0ul; expectMoreFeatures_ = true; } @@ -37,4 +39,18 @@ void BufferedLabelScorer::addInput(DataView const& input) { inputBuffer_.push_back(input); } +void BufferedLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + if (inputBuffer_.empty()) { + return; + } + + auto minActiveTime = minActiveTimeIndex(activeContexts); + if (minActiveTime > numDeletedInputs_) { + size_t deleteInputs = minActiveTime - numDeletedInputs_; + deleteInputs = std::min(deleteInputs, inputBuffer_.size()); + inputBuffer_.erase(inputBuffer_.begin(), inputBuffer_.begin() + deleteInputs); + numDeletedInputs_ += deleteInputs; + } +} + } // namespace Nn diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.hh b/src/Nn/LabelScorer/BufferedLabelScorer.hh index 566d6ead..e6d453f9 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.hh +++ b/src/Nn/LabelScorer/BufferedLabelScorer.hh @@ -16,8 +16,10 @@ #ifndef BUFFERED_LABEL_SCORER_HH #define BUFFERED_LABEL_SCORER_HH +#include #include "DataView.hh" #include "LabelScorer.hh" +#include "Speech/Types.hh" namespace Nn { @@ -42,9 +44,15 @@ public: // Add a single input feature to the buffer virtual void addInput(DataView const& input) override; + // Clean up input buffer + virtual void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + protected: - std::vector inputBuffer_; // Buffer that contains all the feature data for the current segment - bool expectMoreFeatures_; // Flag to 
record segment end signal + std::deque inputBuffer_; // Buffer that contains all the feature data for the current segment + size_t numDeletedInputs_; // Count delted inputs in order to adress the correct index in inputBuffer_ + bool expectMoreFeatures_; // Flag to record segment end signal + + virtual Speech::TimeframeIndex minActiveTimeIndex(Core::CollapsedVector const& activeContexts) const = 0; }; } // namespace Nn diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index cf69cba4..fa72656b 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -71,6 +71,12 @@ ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& requ return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); } +void CombineLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + for (auto& scaledScorer : scaledScorers_) { + scaledScorer.scorer->cleanupCaches(activeContexts); + } +} + void CombineLabelScorer::addInput(DataView const& input) { for (auto& scaledScorer : scaledScorers_) { scaledScorer.scorer->addInput(input); diff --git a/src/Nn/LabelScorer/CombineLabelScorer.hh b/src/Nn/LabelScorer/CombineLabelScorer.hh index 01135686..a6baf688 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.hh +++ b/src/Nn/LabelScorer/CombineLabelScorer.hh @@ -38,28 +38,31 @@ public: virtual ~CombineLabelScorer() = default; // Reset all sub-scorers - void reset(); + void reset() override; // Forward signal to all sub-scorers - void signalNoMoreFeatures(); + void signalNoMoreFeatures() override; // Combine initial ScoringContexts from all sub-scorers - ScoringContextRef getInitialScoringContext(); + ScoringContextRef getInitialScoringContext() override; // Combine extended ScoringContexts from all sub-scorers - ScoringContextRef extendedScoringContext(Request const& request); + ScoringContextRef extendedScoringContext(Request const& request) override; + + // 
Cleanup all sub-scorers + void cleanupCaches(Core::CollapsedVector const& activeContexts) override; // Add input to all sub-scorers - void addInput(DataView const& input); + void addInput(DataView const& input) override; // Add inputs to all sub-scorers - virtual void addInputs(DataView const& input, size_t nTimesteps); + virtual void addInputs(DataView const& input, size_t nTimesteps) override; // Compute weighted score of request with all sub-scorers - std::optional computeScoreWithTime(Request const& request); + std::optional computeScoreWithTime(Request const& request) override; // Compute weighted scores of requests with all sub-scorers - std::optional computeScoresWithTimes(const std::vector& requests); + std::optional computeScoresWithTimes(std::vector const& requests) override; protected: struct ScaledLabelScorer { diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc index 9c8b4363..257c08e1 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc @@ -37,6 +37,10 @@ ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request cons return decoder_->extendedScoringContext(request); } +void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + decoder_->cleanupCaches(activeContexts); +} + void EncoderDecoderLabelScorer::addInput(DataView const& input) { encoder_->addInput(input); passEncoderOutputsToDecoder(); diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh index 498c1224..204204aa 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh @@ -49,6 +49,10 @@ public: // Get extended context from decoder component ScoringContextRef extendedScoringContext(Request const& request) override; + // Cleanup decoder component. 
Encoder is "self-cleaning" already in that it only stores outputs until they are + // retrieved. + void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + // Add an input feature to the encoder component and if possible forward the encoder and add // the encoder states as inputs to the decoder component void addInput(DataView const& input) override; diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 40b26007..6666d28d 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -74,9 +74,7 @@ namespace Nn { class LabelScorer : public virtual Core::Component, public Core::ReferenceCounted { public: - typedef Search::Score Score; - typedef Flow::Vector FeatureVector; - typedef Flow::DataPtr FeatureVectorRef; + typedef Search::Score Score; enum TransitionType { LABEL_TO_LABEL, @@ -123,6 +121,10 @@ public: // Creates a copy of the context in the request that is extended using the given token and transition type virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; + // Given a collection of currently active contexts, this function can clean up values in any internal caches + // or buffers that are saved for scoring contexts which no longer are active. 
+ virtual void cleanupCaches(Core::CollapsedVector const& activeContexts) {}; + // Add a single input feature virtual void addInput(DataView const& input) = 0; diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index c20b098a..7d611c83 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -15,6 +15,7 @@ #include "NoOpLabelScorer.hh" #include "ScoringContext.hh" +#include "Speech/Types.hh" namespace Nn { @@ -39,4 +40,14 @@ std::optional StepwiseNoOpLabelScorer::computeScoreW return ScoreWithTime{inputBuffer_.at(stepHistory->currentStep)[request.nextToken], stepHistory->currentStep}; } +Speech::TimeframeIndex StepwiseNoOpLabelScorer::minActiveTimeIndex(Core::CollapsedVector const& activeContexts) const { + auto minTimeIndex = Core::Type::max; + for (auto const& context : activeContexts) { + StepScoringContextRef stepHistory(dynamic_cast(context.get())); + minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep); + } + + return minTimeIndex; +} + } // namespace Nn diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.hh b/src/Nn/LabelScorer/NoOpLabelScorer.hh index 673baace..483539b8 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.hh +++ b/src/Nn/LabelScorer/NoOpLabelScorer.hh @@ -17,6 +17,7 @@ #define NO_OP_LABEL_SCORER_HH #include "BufferedLabelScorer.hh" +#include "Speech/Types.hh" namespace Nn { @@ -37,10 +38,13 @@ public: ScoringContextRef getInitialScoringContext() override; // Scoring context with step incremented by 1. 
- virtual ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override; + ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override; // Gets the buffered score for the requested token at the requested step std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + +protected: + Speech::TimeframeIndex minActiveTimeIndex(Core::CollapsedVector const& activeContexts) const override; }; } // namespace Nn diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index c3120a8b..b08840f3 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -337,22 +338,31 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { recombination(newBeam_); numActiveHyps_ += newBeam_.size(); - if (logStepwiseStatistics_) { - clog() << Core::XmlFull("active-hyps", newBeam_.size()); + /* + * Clean up label scorer caches. + */ + Core::CollapsedVector activeContexts; + for (auto const& hyp : newBeam_) { + activeContexts.push_back(hyp.scoringContext); } + labelScorer_->cleanupCaches(activeContexts); + + /* + * Log statistics about the new beam after this step. 
+ */ + beam_.swap(newBeam_); if (debugChannel_.isOpen()) { std::stringstream ss; - for (size_t hypIdx = 0ul; hypIdx < newBeam_.size(); ++hypIdx) { - ss << "Hypothesis " << hypIdx + 1ul << ": " << newBeam_[hypIdx].toString() << "\n"; + for (size_t hypIdx = 0ul; hypIdx < beam_.size(); ++hypIdx) { + ss << "Hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; } ss << "\n"; debugChannel_ << ss.str(); } - beam_.swap(newBeam_); - if (logStepwiseStatistics_) { + clog() << Core::XmlFull("active-hyps", beam_.size()); clog() << Core::XmlFull("best-hyp-score", getBestHypothesis().score); clog() << Core::XmlFull("worst-hyp-score", getWorstHypothesis().score); clog() << Core::XmlClose("search-step-stats"); From dcf65e5002fd73619393b8647895c5bfb523db1f Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 9 Apr 2025 18:56:02 +0200 Subject: [PATCH 37/52] Introduce function to get input at correct timestep in BufferedLabelScorer --- src/Nn/LabelScorer/BufferedLabelScorer.cc | 16 ++++++++++++++-- src/Nn/LabelScorer/BufferedLabelScorer.hh | 9 ++++++--- src/Nn/LabelScorer/NoOpLabelScorer.cc | 5 +++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.cc b/src/Nn/LabelScorer/BufferedLabelScorer.cc index cbc4348d..43d18e6f 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.cc +++ b/src/Nn/LabelScorer/BufferedLabelScorer.cc @@ -20,9 +20,9 @@ namespace Nn { BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config) : Core::Component(config), Precursor(config), + expectMoreFeatures_(true), inputBuffer_(), - numDeletedInputs_(0ul), - expectMoreFeatures_(true) { + numDeletedInputs_(0ul) { } void BufferedLabelScorer::reset() { @@ -52,5 +52,17 @@ void BufferedLabelScorer::cleanupCaches(Core::CollapsedVector numDeletedInputs_ += deleteInputs; } } +std::optional BufferedLabelScorer ::getInput(Speech::TimeframeIndex timeIndex) const { + if (timeIndex < numDeletedInputs_) { + error("Tried to get input 
feature that was already cleaned up."); + } + + size_t bufferPosition = timeIndex - numDeletedInputs_; + if (bufferPosition >= inputBuffer_.size()) { + return {}; + } + + return inputBuffer_[bufferPosition]; +} } // namespace Nn diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.hh b/src/Nn/LabelScorer/BufferedLabelScorer.hh index e6d453f9..1c6d38f4 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.hh +++ b/src/Nn/LabelScorer/BufferedLabelScorer.hh @@ -48,11 +48,14 @@ public: virtual void cleanupCaches(Core::CollapsedVector const& activeContexts) override; protected: - std::deque inputBuffer_; // Buffer that contains all the feature data for the current segment - size_t numDeletedInputs_; // Count delted inputs in order to adress the correct index in inputBuffer_ - bool expectMoreFeatures_; // Flag to record segment end signal + bool expectMoreFeatures_; // Flag to record segment end signal virtual Speech::TimeframeIndex minActiveTimeIndex(Core::CollapsedVector const& activeContexts) const = 0; std::optional getInput(Speech::TimeframeIndex timeIndex) const; + +private: + std::deque inputBuffer_; // Buffer that contains all the feature data for the current segment + size_t numDeletedInputs_; // Count deleted inputs in order to address the correct index in inputBuffer_ }; } // namespace Nn diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index 7d611c83..9b88b315 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -33,11 +33,12 @@ ScoringContextRef StepwiseNoOpLabelScorer::extendedScoringContext(LabelScorer::R std::optional StepwiseNoOpLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { StepScoringContextRef stepHistory(dynamic_cast(request.context.get())); - if (inputBuffer_.size() <= stepHistory->currentStep) { + auto input = getInput(stepHistory->currentStep); + if (not input) { return {}; } - return
ScoreWithTime{inputBuffer_.at(stepHistory->currentStep)[request.nextToken], stepHistory->currentStep}; + return ScoreWithTime{(*input)[request.nextToken], stepHistory->currentStep}; } Speech::TimeframeIndex StepwiseNoOpLabelScorer::minActiveTimeIndex(Core::CollapsedVector const& activeContexts) const { From 2e5bc56dba1cf6fd4c22083525c9ef33fdbe9ee4 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 16 Apr 2025 18:17:13 +0200 Subject: [PATCH 38/52] Extract sub-ScoringContexts from CombineScoringContexts --- src/Nn/LabelScorer/CombineLabelScorer.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index fa72656b..d9c3f78f 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -72,8 +72,20 @@ ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& requ } void CombineLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { - for (auto& scaledScorer : scaledScorers_) { - scaledScorer.scorer->cleanupCaches(activeContexts); + Core::CollapsedVector combineContexts; + combineContexts.reserve(activeContexts.size()); + for (auto const& activeContext : activeContexts) { + combineContexts.push_back(dynamic_cast(activeContext.get())); + } + + for (size_t scorerIdx = 0ul; scorerIdx < scaledScorers_.size(); ++scorerIdx) { + auto const& scaledScorer = scaledScorers_[scorerIdx]; + Core::CollapsedVector subScoringContexts; + for (auto const& combineContext : combineContexts) { + subScoringContexts.push_back(combineContext->scoringContexts[scorerIdx]); + } + + scaledScorer.scorer->cleanupCaches(subScoringContexts); } } From b7696e76a7c2bd870d32115042a67fdeee5e8537 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 9 May 2025 14:32:12 +0200 Subject: [PATCH 39/52] Formatting and include fixes --- src/Nn/LabelScorer/BufferedLabelScorer.cc | 2 +- 
src/Nn/LabelScorer/NoOpLabelScorer.cc | 1 - src/Nn/LabelScorer/NoOpLabelScorer.hh | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.cc b/src/Nn/LabelScorer/BufferedLabelScorer.cc index 43d18e6f..8686fa77 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.cc +++ b/src/Nn/LabelScorer/BufferedLabelScorer.cc @@ -52,7 +52,7 @@ void BufferedLabelScorer::cleanupCaches(Core::CollapsedVector numDeletedInputs_ += deleteInputs; } } -std::optional BufferedLabelScorer ::getInput(Speech::TimeframeIndex timeIndex) const { +std::optional BufferedLabelScorer::getInput(Speech::TimeframeIndex timeIndex) const { if (timeIndex < numDeletedInputs_) { error("Tried to get input feature that was already cleaned up."); } diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index 9b88b315..30c099c1 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -15,7 +15,6 @@ #include "NoOpLabelScorer.hh" #include "ScoringContext.hh" -#include "Speech/Types.hh" namespace Nn { diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.hh b/src/Nn/LabelScorer/NoOpLabelScorer.hh index 483539b8..81de8e6b 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.hh +++ b/src/Nn/LabelScorer/NoOpLabelScorer.hh @@ -17,7 +17,6 @@ #define NO_OP_LABEL_SCORER_HH #include "BufferedLabelScorer.hh" -#include "Speech/Types.hh" namespace Nn { From f22812b5eabfc47bc77098c2963ed96fb25d5058 Mon Sep 17 00:00:00 2001 From: Eugen Beck Date: Fri, 9 May 2025 19:31:40 +0200 Subject: [PATCH 40/52] Add handling of otherRootStates in PersistentStateTree and StaticSearchAutomaton (#125) --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 3 +++ src/Search/PersistentStateTree.cc | 24 ++++++++------------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 96b87e12..602fc081 100644 --- 
a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -417,6 +417,9 @@ void StaticSearchAutomaton::buildDepths(bool onlyFromRoot) { invertedStateDepths.resize(network.structure.stateCount(), Core::Type::min); fillStateDepths(network.rootState, 0); fillStateDepths(network.ciRootState, 0); + for (StateId root : network.otherRootStates) { + fillStateDepths(root, 0); + } bool offsetted = false; diff --git a/src/Search/PersistentStateTree.cc b/src/Search/PersistentStateTree.cc index 9db4657c..adcbdfb7 100644 --- a/src/Search/PersistentStateTree.cc +++ b/src/Search/PersistentStateTree.cc @@ -33,7 +33,7 @@ static const Core::ParameterString paramCacheArchive( "cache archive in which the persistent state-network should be cached", "global-cache"); -static u32 formatVersion = 12; +static u32 formatVersion = 13; namespace Search { struct ConvertTree { @@ -291,18 +291,14 @@ MappedArchiveWriter& operator<<(MappedArchiveWriter& writer, const std::map> dummyIndex >> dependenciesChecksum; + in >> dependenciesChecksum; if (dependenciesChecksum != dependencies_.getChecksum()) { Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum; return false; } - if (!structure.read(in)) + if (!structure.read(in)) { return false; + } in >> exits; in >> coarticulatedRootStates >> unpushedCoarticulatedRootStates >> rootTransitDescriptions; in >> pushedWordEndNodes >> uncoarticulatedWordEndStates; - in >> rootState >> ciRootState; + in >> rootState >> ciRootState >> otherRootStates; return in.good(); } @@ -364,6 +357,9 @@ void PersistentStateTree::removeOutputs() { std::set roots = coarticulatedRootStates; roots.insert(rootState); roots.insert(ciRootState); + for (StateId root : otherRootStates) { + roots.insert(root); + } // Also collect all transition-successors as coarticulated roots for (StateId node = 1; node < structure.stateCount(); ++node) { From 
c9e649826080f426f21d7f06e5f444459ffc9783 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 12 May 2025 15:52:29 +0200 Subject: [PATCH 41/52] Add new LexiconfreeLabelsyncBeamSearch search algorithm --- Modules.make | 9 +- .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + config/cc-gcc.make | 2 +- config/os-linux.make | 50 +- src/Nn/LabelScorer/LabelScorer.hh | 1 + .../LexiconfreeLabelsyncBeamSearch.cc | 592 ++++++++++++++++++ .../LexiconfreeLabelsyncBeamSearch.hh | 184 ++++++ .../LexiconfreeLabelsyncBeamSearch/Makefile | 24 + src/Search/Makefile | 4 + 14 files changed, 837 insertions(+), 35 deletions(-) create mode 100644 src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc create mode 100644 src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh create mode 100644 src/Search/LexiconfreeLabelsyncBeamSearch/Makefile diff --git a/Modules.make b/Modules.make index 0f233586..9dba067e 100644 --- a/Modules.make +++ b/Modules.make @@ -33,7 +33,7 @@ MODULES += MODULE_AUDIO_RAW MODULES += MODULE_AUDIO_WAV_SYSTEM # ****** Cache Manager integration ****** -# MODULES += MODULE_CORE_CACHE_MANAGER +MODULES += MODULE_CORE_CACHE_MANAGER # ****** Cart ****** MODULES += MODULE_CART @@ -67,11 +67,11 @@ MODULES += MODULE_NN_SEQUENCE_TRAINING MODULES += MODULE_PYTHON # ****** OpenFst ****** -MODULES += MODULE_OPENFST +# MODULES += MODULE_OPENFST # ****** Search ****** MODULES += MODULE_SEARCH_MBR -MODULES += MODULE_SEARCH_WFST +# MODULES += MODULE_SEARCH_WFST MODULES += MODULE_SEARCH_LINEAR MODULES += MODULE_ADVANCED_TREE_SEARCH @@ -108,7 +108,7 @@ MODULES += MODULE_TEST MODULES += MODULE_TENSORFLOW # ONNX integration -# MODULES += MODULE_ONNX +MODULES += MODULE_ONNX # define variables for the makefiles $(foreach module, $(MODULES), $(eval $(module) = 1)) @@ -147,6 +147,7 @@ endif # ****** 
Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index dbd33406..e35588c2 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -142,6 +142,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index dbd33406..e35588c2 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -142,6 +142,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index e8db4f15..8a230690 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ 
b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -142,6 +142,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index 906530d1..695e603d 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -146,6 +146,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index b4b900e0..9524b589 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -147,6 +147,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/apptainer/2025-04-23_tensorflow-2.17_onnx-1.20_v1/makefiles/Modules.make 
b/apptainer/2025-04-23_tensorflow-2.17_onnx-1.20_v1/makefiles/Modules.make index f578a351..9dba067e 100644 --- a/apptainer/2025-04-23_tensorflow-2.17_onnx-1.20_v1/makefiles/Modules.make +++ b/apptainer/2025-04-23_tensorflow-2.17_onnx-1.20_v1/makefiles/Modules.make @@ -147,6 +147,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeLabelsyncBeamSearch/libSprintLexiconfreeLabelsyncBeamSearch.$(a) LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) diff --git a/config/cc-gcc.make b/config/cc-gcc.make index ab945eb9..c221aab7 100644 --- a/config/cc-gcc.make +++ b/config/cc-gcc.make @@ -28,7 +28,7 @@ CCFLAGS += -funsigned-char CCFLAGS += -fno-exceptions CFLAGS += -std=c99 CXXFLAGS += -std=c++17 -CXXFLAGS += -Wno-unknown-pragmas -Werror=return-type +CXXFLAGS += -Wno-unknown-pragmas #CCFLAGS += -pedantic CCFLAGS += -Wall CCFLAGS += -Wno-long-long diff --git a/config/os-linux.make b/config/os-linux.make index 46bdf5c1..af4beed7 100644 --- a/config/os-linux.make +++ b/config/os-linux.make @@ -46,29 +46,21 @@ LDFLAGS += -L$(TBB_DIR)/lib -ltbb endif ifdef MODULE_TENSORFLOW -TF_COMPILE_BASE = /opt/tensorflow/tensorflow - TF_CXXFLAGS = -fexceptions -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/ -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-bin/ -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-genfiles/ -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/eigen_archive/ -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/com_google_protobuf/src/ -TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/com_google_absl/ - -TF_LDFLAGS = -L$(TF_COMPILE_BASE)/bazel-bin/tensorflow -ltensorflow_cc -ltensorflow_framework -TF_LDFLAGS += -Wl,-rpath -Wl,$(TF_COMPILE_BASE)/bazel-bin/tensorflow +TF_CXXFLAGS += -I/usr/local/lib/python3.11/dist-packages/tensorflow/include +TF_LDFLAGS += 
-Wl,--no-as-needed -Wl,--allow-multiple-definition +TF_LDFLAGS += -lcrypto +TF_LDFLAGS += -L/usr/local/lib/python3.11/dist-packages/tensorflow +TF_LDFLAGS += -Wl,-rpath -Wl,/usr/local/lib/python3.11/dist-packages/tensorflow +TF_LDFLAGS += -l:libtensorflow_cc.so.2 -l:libtensorflow_framework.so.2 # USE_TENSORFLOW_MKL=1 endif - ifdef MODULE_ONNX LDFLAGS += -lonnxruntime -ifndef MODULE_TENSORFLOW CXXFLAGS += -fexceptions endif -endif # ----------------------------------------------------------------------------- # system Libraries @@ -108,25 +100,22 @@ INCLUDES += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/mkl_linux/include/ LDFLAGS += -lmklml_intel -liomp5 LDFLAGS += -llapack else -INCLUDES += `pkg-config --cflags blas` -INCLUDES += `pkg-config --cflags lapack` -LDFLAGS += `pkg-config --libs blas` -LDFLAGS += `pkg-config --libs lapack` +INCLUDES += -I/usr/include/openblas +LDFLAGS += -llapack -lopenblas endif endif endif ifdef MODULE_CUDA -CUDAROOT = /usr/local/cuda-7.0 +CUDAROOT = /usr/local/cuda-11.6 INCLUDES += -I$(CUDAROOT)/include/ LDFLAGS += -L$(CUDAROOT)/lib64/ -lcublas -lcudart -lcurand NVCC = $(CUDAROOT)/bin/nvcc # optimal for GTX680; set sm_35 for K20 -NVCCFLAGS = -gencode arch=compute_20,code=sm_20 \ - -gencode arch=compute_30,code=sm_30 \ - -gencode arch=compute_35,code=sm_35 \ - -gencode arch=compute_52,code=sm_52 \ - -gencode arch=compute_61,code=sm_61 +NVCCFLAGS = -gencode arch=compute_61,code=sm_61 \ # GTX 1080 + -gencode arch=compute_75,code=sm_75 \ # RTX 2080 + -gencode arch=compute_86,code=sm_86 \ # RTX 3090 + --compiler-options -fPIC endif ifeq ($(PROFILE),gprof) @@ -155,18 +144,19 @@ ifdef MODULE_PYTHON # Use --ldflags --embed for python >= 3.8 PYTHON_PATH = ifneq (${PYTHON_PATH},) -INCLUDES += `${PYTHON_PATH}/bin/python3-config --includes 2>/dev/null` -LDFLAGS += `${PYTHON_PATH}/bin/python3-config --ldflags 2>/dev/null` +INCLUDES += `${PYTHON_PATH}/bin/python3.11-config --includes 2>/dev/null` +LDFLAGS += `${PYTHON_PATH}/bin/python3.11-config 
--ldflags --embed 2>/dev/null` LDFLAGS += -Wl,-rpath -Wl,${PYTHON_PATH}/lib else -INCLUDES += `python3-config --includes 2>/dev/null` -INCLUDES += -I$(shell python3 -c 'import numpy as np; print(np.get_include())') -INCLUDES += `python3 -m pybind11 --includes 2>/dev/null` -LDFLAGS += `python3-config --ldflags --embed 2>/dev/null` +INCLUDES += `python3.11-config --includes 2>/dev/null` +INCLUDES += -I$(shell python3.11 -c 'import numpy as np; print(np.get_include())') +INCLUDES += `python3.11 -m pybind11 --includes 2>/dev/null` +LDFLAGS += `python3.11-config --ldflags --embed 2>/dev/null` # IF you want to use Python2 for whatever reason: # INCLUDES += `pkg-config --cflags python` # LDFLAGS += `pkg-config --libs python` endif +INCLUDES += -I$(shell python3.11 -c 'import numpy as np; print(np.get_include())') endif # X11 and QT diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 6666d28d..c186964e 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -84,6 +84,7 @@ public: BLANK_LOOP, INITIAL_LABEL, INITIAL_BLANK, + SENTENCE_END }; // Request for scoring or context extension diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc new file mode 100644 index 00000000..ee3fdd2e --- /dev/null +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -0,0 +1,592 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LexiconfreeLabelsyncBeamSearch.hh" + +#include +#include + +#include +#include +#include +#include +#include + +namespace Search { + +/* + * ======================= + * === LabelHypothesis === + * ======================= + */ + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis::LabelHypothesis() + : scoringContext(), + currentToken(Core::Type::max), + length(0), + score(0.0), + scaledScore(0.0), + trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))), + isActive(true) {} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis::LabelHypothesis( + LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& base, + LexiconfreeLabelsyncBeamSearch::ExtensionCandidate const& extension, + Nn::ScoringContextRef const& newScoringContext, + float lengthNormScale) + : scoringContext(newScoringContext), + currentToken(extension.nextToken), + length(base.length + 1), + score(extension.score), + scaledScore(score / std::pow(length, lengthNormScale)), + trace(Core::ref(new LatticeTrace( + base.trace, + extension.pron, + extension.timeframe + 1, + {extension.score, 0}, + {}))), + isActive(extension.transitionType != Nn::LabelScorer::TransitionType::SENTENCE_END) { +} + +std::string LexiconfreeLabelsyncBeamSearch::LabelHypothesis::toString() const { + std::stringstream ss; + ss << "Score: " << score << ", traceback: "; + + auto traceback = trace->performTraceback(); + + for (auto& item : *traceback) { + if (item.pronunciation and item.pronunciation->lemma()) { + ss << item.pronunciation->lemma()->symbol() << " "; + } + } + return ss.str(); +} 
+ +/* + * ===================================== + * === LexiconfreeLabelsyncBeamSearch == + * ===================================== + */ + +const Core::ParameterInt LexiconfreeLabelsyncBeamSearch::paramMaxBeamSize( + "max-beam-size", + "Maximum number of hypotheses in the search beam.", + 1, 1); + +const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramScoreThreshold( + "score-threshold", + "Prune any hypotheses with a score that is at least this much worse than the best hypothesis. If not set, no score pruning will be done.", + Core::Type::max, 0); + +const Core::ParameterInt LexiconfreeLabelsyncBeamSearch::paramSentenceEndLabelIndex( + "sentence-end-index", + "Index of the sentence-end label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.", + Core::Type::max); + +const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramLengthNormScale( + "length-norm-scale", + "Scaling factor for the hypothesis length normalization.", + 0.0); + +const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramMaxLabelsPerTimestep( + "max-labels-per-timestep", + "Maximum number of emitted labels per input timestep counted via `addInput`/`addInputs`.", + 1.0); + +const Core::ParameterBool LexiconfreeLabelsyncBeamSearch::paramLogStepwiseStatistics( + "log-stepwise-statistics", + "Log statistics about the beam at every search step.", + false); + +LexiconfreeLabelsyncBeamSearch::LexiconfreeLabelsyncBeamSearch(Core::Configuration const& config) + : Core::Component(config), + SearchAlgorithmV2(config), + maxBeamSize_(paramMaxBeamSize(config)), + scoreThreshold_(paramScoreThreshold(config)), + lengthNormScale_(paramLengthNormScale(config)), + maxLabelsPerTimestep_(paramMaxLabelsPerTimestep(config)), + sentenceEndLabelIndex_(paramSentenceEndLabelIndex(config)), + logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + debugChannel_(config, "debug"), + labelScorer_(), + beam_(), + 
extensions_(), + newBeam_(), + requests_(), + recombinedHypotheses_(), + initializationTime_(), + featureProcessingTime_(), + scoringTime_(), + contextExtensionTime_(), + numHypsAfterScorePruning_("num-hyps-after-score-pruning"), + numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"), + currentSearchStep_(0ul), + totalTimesteps_(0ul), + finishedSegment_(false) { + beam_.reserve(maxBeamSize_); + newBeam_.reserve(maxBeamSize_ * 2); // terminated + active + recombinedHypotheses_.reserve(maxBeamSize_); + + useScorePruning_ = scoreThreshold_ != Core::Type::max; + + log() << "Use sentence-end label with index " << sentenceEndLabelIndex_; +} + +Speech::ModelCombination::Mode LexiconfreeLabelsyncBeamSearch::requiredModelCombination() const { + return Speech::ModelCombination::useLabelScorer | Speech::ModelCombination::useLexicon; +} + +bool LexiconfreeLabelsyncBeamSearch::setModelCombination(Speech::ModelCombination const& modelCombination) { + lexicon_ = modelCombination.lexicon(); + labelScorer_ = modelCombination.labelScorer(); + + extensions_.reserve(maxBeamSize_ * lexicon_->nLemmas()); + requests_.reserve(extensions_.size()); + + auto sentenceEndLemma = lexicon_->specialLemma("sentence-end"); + if (!sentenceEndLemma) { + sentenceEndLemma = lexicon_->specialLemma("sentence-boundary"); + } + if (sentenceEndLemma) { + if (sentenceEndLabelIndex_ == Core::Type::max) { + sentenceEndLabelIndex_ = sentenceEndLemma->id(); + log() << "Use sentence-end index " << sentenceEndLabelIndex_ << " inferred from lexicon"; + } + else if (sentenceEndLabelIndex_ != static_cast(sentenceEndLemma->id())) { + warning() << "SentenceEnd lemma exists in lexicon with id " << sentenceEndLemma->id() << " but is overwritten by config parameter with value " << sentenceEndLabelIndex_; + } + } + + reset(); + return true; +} + +void LexiconfreeLabelsyncBeamSearch::reset() { + initializationTime_.start(); + + labelScorer_->reset(); + + // Reset beam to a single empty hypothesis + beam_.clear(); + 
beam_.push_back(LabelHypothesis()); + beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); + + finishedSegment_ = false; + totalTimesteps_ = 0ul; + currentSearchStep_ = 0ul; + + initializationTime_.stop(); +} + +void LexiconfreeLabelsyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) { + initializationTime_.start(); + labelScorer_->reset(); + resetStatistics(); + initializationTime_.stop(); + finishedSegment_ = false; + totalTimesteps_ = 0ul; + currentSearchStep_ = 0ul; +} + +void LexiconfreeLabelsyncBeamSearch::finishSegment() { + featureProcessingTime_.start(); + labelScorer_->signalNoMoreFeatures(); + featureProcessingTime_.stop(); + decodeManySteps(); + logStatistics(); + finishedSegment_ = true; +} + +void LexiconfreeLabelsyncBeamSearch::putFeature(Nn::DataView const& feature) { + featureProcessingTime_.start(); + labelScorer_->addInput(feature); + ++totalTimesteps_; + featureProcessingTime_.stop(); +} + +void LexiconfreeLabelsyncBeamSearch::putFeatures(Nn::DataView const& features, size_t nTimesteps) { + featureProcessingTime_.start(); + labelScorer_->addInputs(features, nTimesteps); + totalTimesteps_ += nTimesteps; + featureProcessingTime_.stop(); +} + +Core::Ref LexiconfreeLabelsyncBeamSearch::getCurrentBestTraceback() const { + return getBestHypothesis().trace->performTraceback(); +} + +Core::Ref LexiconfreeLabelsyncBeamSearch::getCurrentBestWordLattice() const { + auto& bestHypothesis = getBestHypothesis(); + + LatticeTrace endTrace(bestHypothesis.trace, 0, bestHypothesis.trace->time + 1, bestHypothesis.trace->score, {}); + + for (auto const& hyp : beam_) { + if (hyp.isActive != bestHypothesis.isActive) { + continue; + } + auto siblingTrace = Core::ref(new LatticeTrace(hyp.trace, 0, hyp.trace->time, hyp.trace->score, {})); + endTrace.appendSiblingToChain(siblingTrace); + } + + return endTrace.buildWordLattice(lexicon_); +} + +bool LexiconfreeLabelsyncBeamSearch::decodeStep() { + if (finishedSegment_) { + return false; + } 
+ if (currentSearchStep_ >= maxLabelsPerTimestep_ * std::max(totalTimesteps_, 1ul)) { + warning() << "Terminated search due to reaching max number of labels"; + return false; + } + + // Assume the output labels are stored as lexicon lemma orth and ordered consistently with NN output index + auto lemmas = lexicon_->lemmas(); + + /* + * Collect all possible extensions for all hypotheses in the beam. + * Also Create scoring requests for the label scorer. + * Each extension candidate makes up a request. + */ + extensions_.clear(); + requests_.clear(); + + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + + if (not hyp.isActive) { + continue; + } + + // Iterate over possible successors (all lemmas) + for (auto lemmaIt = lemmas.first; lemmaIt != lemmas.second; ++lemmaIt) { + const Bliss::Lemma* lemma(*lemmaIt); + Nn::LabelIndex tokenIdx = lemma->id(); + + auto transitionType = Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; + if (hyp.currentToken == Core::Type::max) { + transitionType = Nn::LabelScorer::TransitionType::INITIAL_LABEL; + } + if (tokenIdx == sentenceEndLabelIndex_) { + transitionType = Nn::LabelScorer::TransitionType::SENTENCE_END; + } + + extensions_.push_back( + {tokenIdx, + lemma->pronunciations().first, + hyp.score, + 0, + transitionType, + hypIndex}); + requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); + } + } + + if (requests_.empty()) { + // All hypotheses are terminated -> no search step can be made. + return false; + } + + /* + * Perform scoring of all the requests with the label scorer. + */ + scoringTime_.start(); + auto result = labelScorer_->computeScoresWithTimes(requests_); + scoringTime_.stop(); + + if (not result) { + // LabelScorer could not compute scores -> no search step can be made. 
+ return false; + } + + for (size_t extensionIdx = 0ul; extensionIdx < extensions_.size(); ++extensionIdx) { + extensions_[extensionIdx].score += result->scores[extensionIdx]; + extensions_[extensionIdx].timeframe = result->timeframes[extensionIdx]; + } + + if (logStepwiseStatistics_) { + clog() << Core::XmlOpen("search-step-stats"); + } + + /* + * Prune set of possible extensions by max beam size and possibly also by score. + */ + + if (useScorePruning_) { + scorePruningExtensions(); + } + + beamSizePruningExtensions(); + + /* + * Create new beam from surviving extensions. + */ + newBeam_.clear(); + + for (auto const& hyp : beam_) { + if (not hyp.isActive) { + newBeam_.push_back(hyp); + } + } + + for (auto const& extension : extensions_) { + auto const& baseHyp = beam_[extension.baseHypIndex]; + + auto newScoringContext = labelScorer_->extendedScoringContext( + {baseHyp.scoringContext, + extension.nextToken, + extension.transitionType}); + newBeam_.push_back({baseHyp, extension, newScoringContext, lengthNormScale_}); + } + + /* + * For all hypotheses with the same scoring context keep only the best since they will + * all develop in the same way. + */ + recombination(); + + /* + * Prune terminated hypotheses among each other + */ + if (useScorePruning_) { + scorePruning(); + + numHypsAfterScorePruning_ += beam_.size(); + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-score-pruning", beam_.size()); + } + } + + beamSizePruning(); + numHypsAfterBeamPruning_ += beam_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-beam-pruning", beam_.size()); + } + + /* + * Clean up label scorer caches. + */ + Core::CollapsedVector activeContexts; + for (auto const& hyp : newBeam_) { + activeContexts.push_back(hyp.scoringContext); + } + labelScorer_->cleanupCaches(activeContexts); + + /* + * Log statistics about the new beam after this step. 
+ */ + beam_.swap(newBeam_); + + if (debugChannel_.isOpen()) { + std::stringstream ssActive; + std::stringstream ssTerminated; + for (size_t hypIdx = 0ul; hypIdx < beam_.size(); ++hypIdx) { + auto const& hyp = beam_[hypIdx]; + if (hyp.isActive) { + ssActive << "Active hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; + } + else { + ssTerminated << "Terminated hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; + } + } + ssActive << "\n"; + ssTerminated << "\n"; + debugChannel_ << ssActive.str() << ssTerminated.str(); + } + + if (logStepwiseStatistics_) { + size_t numActive = std::accumulate( + beam_.begin(), + beam_.end(), + 0ul, + [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); + auto const& bestHyp = getBestHypothesis(); + auto const& worstHyp = getWorstHypothesis(); + clog() << Core::XmlFull("active-hyps", numActive); + clog() << Core::XmlFull("terminated-hyps", beam_.size() - numActive); + clog() << Core::XmlFull("best-hyp-score", bestHyp.score); + clog() << Core::XmlFull("worst-hyp-score", worstHyp.score); + clog() << Core::XmlFull("best-hyp-normed-score", bestHyp.scaledScore); + clog() << Core::XmlFull("worst-hyp-normed-score", worstHyp.scaledScore); + clog() << Core::XmlClose("search-step-stats"); + } + + ++currentSearchStep_; + return true; +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getBestHypothesis() const { + LabelHypothesis const* bestActive = nullptr; + LabelHypothesis const* bestTerminated = nullptr; + + for (auto const& hyp : beam_) { + if (hyp.isActive) { + if (not bestActive or hyp < *bestActive) { + bestActive = &hyp; + } + } + else { + if (not bestTerminated or hyp < *bestTerminated) { + bestTerminated = &hyp; + } + } + } + + if (bestTerminated) { + return *bestTerminated; + } + else { + return *bestActive; + } +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getWorstHypothesis() 
const { + LabelHypothesis const* worstActive = nullptr; + LabelHypothesis const* worstTerminated = nullptr; + + for (auto const& hyp : beam_) { + if (hyp.isActive) { + if (not worstActive or hyp > *worstActive) { + worstActive = &hyp; + } + } + else { + if (not worstTerminated or hyp > *worstTerminated) { + worstTerminated = &hyp; + } + } + } + + if (worstTerminated) { + return *worstTerminated; + } + else { + return *worstActive; + } +} + +void LexiconfreeLabelsyncBeamSearch::resetStatistics() { + initializationTime_.reset(); + featureProcessingTime_.reset(); + scoringTime_.reset(); + contextExtensionTime_.reset(); + numHypsAfterScorePruning_.clear(); + numHypsAfterBeamPruning_.clear(); +} + +void LexiconfreeLabelsyncBeamSearch::logStatistics() const { + clog() << Core::XmlOpen("timing-statistics") + Core::XmlAttribute("unit", "milliseconds"); + clog() << Core::XmlOpen("initialization-time") << initializationTime_.elapsedMilliseconds() << Core::XmlClose("initialization-time"); + clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.elapsedMilliseconds() << Core::XmlClose("feature-processing-time"); + clog() << Core::XmlOpen("scoring-time") << scoringTime_.elapsedMilliseconds() << Core::XmlClose("scoring-time"); + clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time"); + clog() << Core::XmlClose("timing-statistics"); + numHypsAfterScorePruning_.write(clog()); + numHypsAfterBeamPruning_.write(clog()); +} + +void LexiconfreeLabelsyncBeamSearch::beamSizePruningExtensions() { + if (extensions_.size() <= maxBeamSize_) { + return; + } + + // Reorder the hypotheses by associated score value such that the first `beamSizeActive_` elements are the best + std::nth_element(extensions_.begin(), extensions_.begin() + maxBeamSize_, extensions_.end()); + extensions_.resize(maxBeamSize_); // Get rid of excessive elements +} + +void 
LexiconfreeLabelsyncBeamSearch::beamSizePruning() { + if (beam_.size() <= maxBeamSize_) { + return; + } + + // Reorder the hypotheses by associated score value such that the first `beamSizeTerminated_` elements are the best + std::nth_element(beam_.begin(), beam_.begin() + maxBeamSize_, beam_.end()); + beam_.resize(maxBeamSize_); // Get rid of excessive elements +} + +void LexiconfreeLabelsyncBeamSearch::scorePruningExtensions() { + if (extensions_.empty()) { + return; + } + + // Compute the pruning threshold + auto bestScore = std::min_element(extensions_.begin(), extensions_.end())->score; + auto pruningThreshold = bestScore + scoreThreshold_; + + // Remove elements with score > pruningThreshold + extensions_.erase( + std::remove_if( + extensions_.begin(), + extensions_.end(), + [&](auto const& ext) { return ext.score > pruningThreshold; }), + extensions_.end()); +} + +void LexiconfreeLabelsyncBeamSearch::scorePruning() { + if (beam_.empty()) { + return; + } + + // Compute the pruning threshold + auto bestHyp = *std::min_element( + beam_.begin(), + beam_.end()); + + // Remove elements with score > pruningThreshold + auto pruningThreshold = (bestHyp.score + scoreThreshold_) / std::pow(bestHyp.length, lengthNormScale_); + beam_.erase( + std::remove_if( + beam_.begin(), + beam_.end(), + [&](auto const& hyp) { return hyp.scaledScore > pruningThreshold; }), + beam_.end()); +} + +void LexiconfreeLabelsyncBeamSearch::recombination() { + recombinedHypotheses_.clear(); + // Map each unique ScoringContext in newHypotheses to its hypothesis + std::unordered_map seenScoringContexts; + for (auto const& hyp : newBeam_) { + // Use try_emplace to check if the scoring context already exists and create a new entry if not at the same time + auto [it, inserted] = seenScoringContexts.try_emplace(hyp.scoringContext, nullptr); + + if (inserted) { + // First time seeing this scoring context so move it over to `newHypotheses` + recombinedHypotheses_.push_back(std::move(hyp)); + 
it->second = &recombinedHypotheses_.back(); + } + else { + verify(not hyp.trace->sibling); + + auto* existingHyp = it->second; + if (hyp < *existingHyp) { + // New hyp is better -> replace in `newHypotheses` and add existing one as sibling + hyp.trace->sibling = existingHyp->trace; + *existingHyp = std::move(hyp); // Overwrite in-place + } + else { + // New hyp is worse -> add to existing one as sibling + hyp.trace->sibling = existingHyp->trace->sibling; + existingHyp->trace->sibling = hyp.trace; + } + } + } + + newBeam_.swap(recombinedHypotheses_); +} + +} // namespace Search diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh new file mode 100644 index 00000000..99a245ce --- /dev/null +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh @@ -0,0 +1,184 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LEXICONFREE_LABELSYNC_BEAM_SEARCH_HH +#define LEXICONFREE_LABELSYNC_BEAM_SEARCH_HH + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Search { + +/* + * Simple label synchronous beam search algorithm without pronunciation lexicon, word-level LM or transition model. + * Uses a sentence-end symbol to terminate hypotheses. 
+ * Main purpose is open vocabulary search with AED (or similar) models. + * Supports global pruning by max beam-size and by score difference to the best hypothesis. + * Uses a LabelScorer to context initialization/extension and scoring. + * + * The search requires a lexicon that represents the vocabulary. Each lemma is viewed as a token with its index + * in the lexicon corresponding to the associated output index of the label scorer. + */ +class LexiconfreeLabelsyncBeamSearch : public SearchAlgorithmV2 { +protected: + /* + * Possible extension for some label hypothesis in the beam + */ + struct ExtensionCandidate { + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of lemma corresponding to `nextToken` for traceback + Score score; // Would-be score of full hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam + + bool operator<(ExtensionCandidate const& other) const { + return score < other.score; + } + }; + + /* + * Struct containing all information about a single hypothesis in the beam + */ + struct LabelHypothesis { + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + size_t length; // Number of tokens in hypothesis for length normalization + Score score; // Full score of hypothesis + Score scaledScore; // Length-normalized score of hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building off of hypothesis + bool isActive; // Indicates whether the hypothesis has not produced a sentence-end label yet + + LabelHypothesis(); + LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& 
extension, Nn::ScoringContextRef const& newScoringContext, float lengthNormScale); + + bool operator<(LabelHypothesis const& other) const { + return scaledScore < other.scaledScore; + } + + bool operator>(LabelHypothesis const& other) const { + return scaledScore > other.scaledScore; + } + + /* + * Get string representation for debugging. + */ + std::string toString() const; + }; + +public: + static const Core::ParameterInt paramMaxBeamSize; + static const Core::ParameterFloat paramScoreThreshold; + + static const Core::ParameterInt paramSentenceEndLabelIndex; + static const Core::ParameterFloat paramLengthNormScale; + static const Core::ParameterFloat paramMaxLabelsPerTimestep; + static const Core::ParameterBool paramLogStepwiseStatistics; + + LexiconfreeLabelsyncBeamSearch(Core::Configuration const&); + + // Inherited methods from `SearchAlgorithmV2` + + Speech::ModelCombination::Mode requiredModelCombination() const override; + bool setModelCombination(Speech::ModelCombination const& modelCombination) override; + void reset() override; + void enterSegment(Bliss::SpeechSegment const* = nullptr) override; + void finishSegment() override; + void putFeature(Nn::DataView const& feature) override; + void putFeatures(Nn::DataView const& features, size_t nTimesteps) override; + Core::Ref getCurrentBestTraceback() const override; + Core::Ref getCurrentBestWordLattice() const override; + bool decodeStep() override; + +private: + size_t maxBeamSize_; + + bool useScorePruning_; + Score scoreThreshold_; + + float lengthNormScale_; + + float maxLabelsPerTimestep_; + + Nn::LabelIndex sentenceEndLabelIndex_; + + bool logStepwiseStatistics_; + + Core::Channel debugChannel_; + + Core::Ref labelScorer_; + Bliss::LexiconRef lexicon_; + std::vector beam_; + + // Pre-allocated intermediate vectors + std::vector extensions_; + std::vector newBeam_; + std::vector requests_; + std::vector recombinedHypotheses_; + + Core::StopWatch initializationTime_; + Core::StopWatch 
featureProcessingTime_; + Core::StopWatch scoringTime_; + Core::StopWatch contextExtensionTime_; + + Core::Statistics numHypsAfterScorePruning_; + Core::Statistics numHypsAfterBeamPruning_; + + size_t currentSearchStep_; + size_t totalTimesteps_; + bool finishedSegment_; + + LabelHypothesis const& getBestHypothesis() const; + LabelHypothesis const& getWorstHypothesis() const; + + void resetStatistics(); + void logStatistics() const; + + /* + * Helper function for pruning of extensions to maxBeamSize_ + */ + void beamSizePruningExtensions(); + + /* + * Helper function for pruning of hyps to maxBeamSize_ + */ + void beamSizePruning(); + + /* + * Helper function for pruning of extensions to scoreThreshold_ + */ + void scorePruningExtensions(); + + /* + * Helper function for pruning of hyps to scoreThreshold_ + */ + void scorePruning(); + + /* + * Helper function for recombination of hypotheses with the same scoring context + */ + void recombination(); +}; + +} // namespace Search + +#endif // LEXICONFREE_LABELSYNC_BEAM_SEARCH_HH diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/Makefile b/src/Search/LexiconfreeLabelsyncBeamSearch/Makefile new file mode 100644 index 00000000..2db32eb8 --- /dev/null +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/Makefile @@ -0,0 +1,24 @@ +#!gmake + +TOPDIR = ../../..
+ +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintLexiconfreeLabelsyncBeamSearch.$(a) + +LIBSPRINTLEXICONFREELABELSYNCBEAMSEARCH_O = $(OBJDIR)/LexiconfreeLabelsyncBeamSearch.o + + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintLexiconfreeLabelsyncBeamSearch.$(a): $(LIBSPRINTLEXICONFREELABELSYNCBEAMSEARCH_O) + $(MAKELIB) $@ $^ + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTLEXICONFREELABELSYNCBEAMSEARCH_O:.o=.d) diff --git a/src/Search/Makefile b/src/Search/Makefile index efa49996..f6fc97ec 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -34,6 +34,7 @@ LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskAStarSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif +SUBDIRS += LexiconfreeLabelsyncBeamSearch SUBDIRS += LexiconfreeTimesyncBeamSearch ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst @@ -64,6 +65,9 @@ Wfst: AdvancedTreeSearch: $(MAKE) -C $@ libSprintAdvancedTreeSearch.$(a) + +LexiconfreeLabelsyncBeamSearch: + $(MAKE) -C $@ libSprintLexiconfreeLabelsyncBeamSearch.$(a) LexiconfreeTimesyncBeamSearch: $(MAKE) -C $@ libSprintLexiconfreeTimesyncBeamSearch.$(a) From 556d5d500e4101a7bf34d6168a6a6ac31a7f6d87 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 12 May 2025 15:55:15 +0200 Subject: [PATCH 42/52] Revert Makefile changes that were produced by configure script --- Modules.make | 8 +++---- config/cc-gcc.make | 2 +- config/os-linux.make | 50 ++++++++++++++++++++++++++------------------ 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/Modules.make b/Modules.make index 9dba067e..679fc9f8 100644 --- a/Modules.make +++ b/Modules.make @@ -33,7 +33,7 @@ MODULES += MODULE_AUDIO_RAW MODULES += MODULE_AUDIO_WAV_SYSTEM # ****** Cache Manager integration ****** -MODULES += 
MODULE_CORE_CACHE_MANAGER +# MODULES += MODULE_CORE_CACHE_MANAGER # ****** Cart ****** MODULES += MODULE_CART @@ -67,11 +67,11 @@ MODULES += MODULE_NN_SEQUENCE_TRAINING MODULES += MODULE_PYTHON # ****** OpenFst ****** -# MODULES += MODULE_OPENFST +MODULES += MODULE_OPENFST # ****** Search ****** MODULES += MODULE_SEARCH_MBR -# MODULES += MODULE_SEARCH_WFST +MODULES += MODULE_SEARCH_WFST MODULES += MODULE_SEARCH_LINEAR MODULES += MODULE_ADVANCED_TREE_SEARCH @@ -108,7 +108,7 @@ MODULES += MODULE_TEST MODULES += MODULE_TENSORFLOW # ONNX integration -MODULES += MODULE_ONNX +# MODULES += MODULE_ONNX # define variables for the makefiles $(foreach module, $(MODULES), $(eval $(module) = 1)) diff --git a/config/cc-gcc.make b/config/cc-gcc.make index c221aab7..ab945eb9 100644 --- a/config/cc-gcc.make +++ b/config/cc-gcc.make @@ -28,7 +28,7 @@ CCFLAGS += -funsigned-char CCFLAGS += -fno-exceptions CFLAGS += -std=c99 CXXFLAGS += -std=c++17 -CXXFLAGS += -Wno-unknown-pragmas +CXXFLAGS += -Wno-unknown-pragmas -Werror=return-type #CCFLAGS += -pedantic CCFLAGS += -Wall CCFLAGS += -Wno-long-long diff --git a/config/os-linux.make b/config/os-linux.make index af4beed7..46bdf5c1 100644 --- a/config/os-linux.make +++ b/config/os-linux.make @@ -46,21 +46,29 @@ LDFLAGS += -L$(TBB_DIR)/lib -ltbb endif ifdef MODULE_TENSORFLOW +TF_COMPILE_BASE = /opt/tensorflow/tensorflow + TF_CXXFLAGS = -fexceptions -TF_CXXFLAGS += -I/usr/local/lib/python3.11/dist-packages/tensorflow/include -TF_LDFLAGS += -Wl,--no-as-needed -Wl,--allow-multiple-definition -TF_LDFLAGS += -lcrypto -TF_LDFLAGS += -L/usr/local/lib/python3.11/dist-packages/tensorflow -TF_LDFLAGS += -Wl,-rpath -Wl,/usr/local/lib/python3.11/dist-packages/tensorflow -TF_LDFLAGS += -l:libtensorflow_cc.so.2 -l:libtensorflow_framework.so.2 +TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/ +TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-bin/ +TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-genfiles/ +TF_CXXFLAGS += 
-I$(TF_COMPILE_BASE)/bazel-tensorflow/external/eigen_archive/ +TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/com_google_protobuf/src/ +TF_CXXFLAGS += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/com_google_absl/ + +TF_LDFLAGS = -L$(TF_COMPILE_BASE)/bazel-bin/tensorflow -ltensorflow_cc -ltensorflow_framework +TF_LDFLAGS += -Wl,-rpath -Wl,$(TF_COMPILE_BASE)/bazel-bin/tensorflow # USE_TENSORFLOW_MKL=1 endif + ifdef MODULE_ONNX LDFLAGS += -lonnxruntime +ifndef MODULE_TENSORFLOW CXXFLAGS += -fexceptions endif +endif # ----------------------------------------------------------------------------- # system Libraries @@ -100,22 +108,25 @@ INCLUDES += -I$(TF_COMPILE_BASE)/bazel-tensorflow/external/mkl_linux/include/ LDFLAGS += -lmklml_intel -liomp5 LDFLAGS += -llapack else -INCLUDES += -I/usr/include/openblas -LDFLAGS += -llapack -lopenblas +INCLUDES += `pkg-config --cflags blas` +INCLUDES += `pkg-config --cflags lapack` +LDFLAGS += `pkg-config --libs blas` +LDFLAGS += `pkg-config --libs lapack` endif endif endif ifdef MODULE_CUDA -CUDAROOT = /usr/local/cuda-11.6 +CUDAROOT = /usr/local/cuda-7.0 INCLUDES += -I$(CUDAROOT)/include/ LDFLAGS += -L$(CUDAROOT)/lib64/ -lcublas -lcudart -lcurand NVCC = $(CUDAROOT)/bin/nvcc # optimal for GTX680; set sm_35 for K20 -NVCCFLAGS = -gencode arch=compute_61,code=sm_61 \ # GTX 1080 - -gencode arch=compute_75,code=sm_75 \ # RTX 2080 - -gencode arch=compute_86,code=sm_86 \ # RTX 3090 - --compiler-options -fPIC +NVCCFLAGS = -gencode arch=compute_20,code=sm_20 \ + -gencode arch=compute_30,code=sm_30 \ + -gencode arch=compute_35,code=sm_35 \ + -gencode arch=compute_52,code=sm_52 \ + -gencode arch=compute_61,code=sm_61 endif ifeq ($(PROFILE),gprof) @@ -144,19 +155,18 @@ ifdef MODULE_PYTHON # Use --ldflags --embed for python >= 3.8 PYTHON_PATH = ifneq (${PYTHON_PATH},) -INCLUDES += `${PYTHON_PATH}/bin/python3.11-config --includes 2>/dev/null` -LDFLAGS += `${PYTHON_PATH}/bin/python3.11-config --ldflags --embed 2>/dev/null` +INCLUDES 
+= `${PYTHON_PATH}/bin/python3-config --includes 2>/dev/null` +LDFLAGS += `${PYTHON_PATH}/bin/python3-config --ldflags 2>/dev/null` LDFLAGS += -Wl,-rpath -Wl,${PYTHON_PATH}/lib else -INCLUDES += `python3.11-config --includes 2>/dev/null` -INCLUDES += -I$(shell python3.11 -c 'import numpy as np; print(np.get_include())') -INCLUDES += `python3.11 -m pybind11 --includes 2>/dev/null` -LDFLAGS += `python3.11-config --ldflags --embed 2>/dev/null` +INCLUDES += `python3-config --includes 2>/dev/null` +INCLUDES += -I$(shell python3 -c 'import numpy as np; print(np.get_include())') +INCLUDES += `python3 -m pybind11 --includes 2>/dev/null` +LDFLAGS += `python3-config --ldflags --embed 2>/dev/null` # IF you want to use Python2 for whatever reason: # INCLUDES += `pkg-config --cflags python` # LDFLAGS += `pkg-config --libs python` endif -INCLUDES += -I$(shell python3.11 -c 'import numpy as np; print(np.get_include())') endif # X11 and QT From 6bf2d0ae0a00cf2326b22f96d844a1e6a498240f Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 12 May 2025 16:16:05 +0200 Subject: [PATCH 43/52] Improve docstrings --- .../LexiconfreeLabelsyncBeamSearch.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index ee3fdd2e..f027a776 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -87,17 +87,19 @@ const Core::ParameterInt LexiconfreeLabelsyncBeamSearch::paramMaxBeamSize( const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramScoreThreshold( "score-threshold", - "Prune any hypotheses with a score that is at least this much worse than the best hypothesis. 
If not set, no score pruning will be done.", + "Prune any hypotheses with a score that is at least this much worse than the best hypothesis. " + "If length normalization is enabled, the score threshold is added to the raw score before normalization. " + "If not set, no score pruning will be done.", Core::Type::max, 0); const Core::ParameterInt LexiconfreeLabelsyncBeamSearch::paramSentenceEndLabelIndex( "sentence-end-index", - "Index of the sentence-end label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.", - Core::Type::max); + "Index of the sentence-end label in the lexicon. " + "Can also be inferred from lexicon if it has a lemma with `special='sentence-end'` or `special='sentence-boundary'`"); const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramLengthNormScale( "length-norm-scale", - "Scaling factor for the hypothesis length normalization.", + "Exponent of length for the hypothesis length normalization. 
Scaled scores are computed as score / length^length_norm_scale.", 0.0); const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramMaxLabelsPerTimestep( From 62c7b8c2ec5349349d007fa8e6114befb6e5b487 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 19 May 2025 20:01:41 +0200 Subject: [PATCH 44/52] Bugfix: Do pruning on newBeam_ instead of beam_ --- .../LexiconfreeLabelsyncBeamSearch.cc | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index f027a776..281725bd 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -365,17 +365,17 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { if (useScorePruning_) { scorePruning(); - numHypsAfterScorePruning_ += beam_.size(); + numHypsAfterScorePruning_ += newBeam_.size(); if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-hyps-after-score-pruning", beam_.size()); + clog() << Core::XmlFull("num-hyps-after-score-pruning", newBeam_.size()); } } beamSizePruning(); - numHypsAfterBeamPruning_ += beam_.size(); + numHypsAfterBeamPruning_ += newBeam_.size(); if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-hyps-after-beam-pruning", beam_.size()); + clog() << Core::XmlFull("num-hyps-after-beam-pruning", newBeam_.size()); } /* @@ -511,13 +511,13 @@ void LexiconfreeLabelsyncBeamSearch::beamSizePruningExtensions() { } void LexiconfreeLabelsyncBeamSearch::beamSizePruning() { - if (beam_.size() <= maxBeamSize_) { + if (newBeam_.size() <= maxBeamSize_) { return; } // Reorder the hypotheses by associated score value such that the first `beamSizeTerminated_` elements are the best - std::nth_element(beam_.begin(), beam_.begin() + maxBeamSize_, beam_.end()); - beam_.resize(maxBeamSize_); // Get rid of 
excessive elements + std::nth_element(newBeam_.begin(), newBeam_.begin() + maxBeamSize_, newBeam_.end()); + newBeam_.resize(maxBeamSize_); // Get rid of excessive elements } void LexiconfreeLabelsyncBeamSearch::scorePruningExtensions() { @@ -539,23 +539,23 @@ void LexiconfreeLabelsyncBeamSearch::scorePruningExtensions() { } void LexiconfreeLabelsyncBeamSearch::scorePruning() { - if (beam_.empty()) { + if (newBeam_.empty()) { return; } // Compute the pruning threshold auto bestHyp = *std::min_element( - beam_.begin(), - beam_.end()); + newBeam_.begin(), + newBeam_.end()); // Remove elements with score > pruningThreshold auto pruningThreshold = (bestHyp.score + scoreThreshold_) / std::pow(bestHyp.length, lengthNormScale_); - beam_.erase( + newBeam_.erase( std::remove_if( - beam_.begin(), - beam_.end(), + newBeam_.begin(), + newBeam_.end(), [&](auto const& hyp) { return hyp.scaledScore > pruningThreshold; }), - beam_.end()); + newBeam_.end()); } void LexiconfreeLabelsyncBeamSearch::recombination() { From cfe36246a59037df63a46e83327d526871cd7027 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 19 May 2025 22:22:12 +0200 Subject: [PATCH 45/52] Fix comment string --- .../LexiconfreeLabelsyncBeamSearch.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index 281725bd..278cf92e 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -360,7 +360,7 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { recombination(); /* - * Prune terminated hypotheses among each other + * Jointly prune terminated and active hypotheses */ if (useScorePruning_) { scorePruning(); From 872b0ff9413fb85ef03dcd5939ec63949d667458 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 28 May 2025 17:57:22 
+0200 Subject: [PATCH 46/52] Apply suggestions from code review + more verbose step logging --- .../LexiconfreeLabelsyncBeamSearch.cc | 195 ++++++++++++------ .../LexiconfreeLabelsyncBeamSearch.hh | 14 +- .../LexiconfreeTimesyncBeamSearch.cc | 2 + src/Search/Module.cc | 5 + src/Search/Module.hh | 4 +- 5 files changed, 153 insertions(+), 67 deletions(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index 278cf92e..813b6f24 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -132,14 +132,16 @@ LexiconfreeLabelsyncBeamSearch::LexiconfreeLabelsyncBeamSearch(Core::Configurati featureProcessingTime_(), scoringTime_(), contextExtensionTime_(), - numHypsAfterScorePruning_("num-hyps-after-score-pruning"), - numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"), + numTerminatedHypsAfterScorePruning_("num-terminated-hyps-after-score-pruning"), + numTerminatedHypsAfterBeamPruning_("num-terminated-hyps-after-beam-pruning"), + numActiveHypsAfterScorePruning_("num-active-hyps-after-score-pruning"), + numActiveHypsAfterBeamPruning_("num-active-hyps-after-beam-pruning"), currentSearchStep_(0ul), totalTimesteps_(0ul), finishedSegment_(false) { beam_.reserve(maxBeamSize_); - newBeam_.reserve(maxBeamSize_ * 2); // terminated + active - recombinedHypotheses_.reserve(maxBeamSize_); + newBeam_.reserve(maxBeamSize_ * 2); // terminated + active + recombinedHypotheses_.reserve(maxBeamSize_ * 2); // terminated + active useScorePruning_ = scoreThreshold_ != Core::Type::max; @@ -249,8 +251,9 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { if (finishedSegment_) { return false; } - if (currentSearchStep_ >= maxLabelsPerTimestep_ * std::max(totalTimesteps_, 1ul)) { - warning() << "Terminated search due to reaching max number of labels"; + if
(currentSearchStep_ >= maxLabelsPerTimestep_ * totalTimesteps_) { + warning() << "Terminated search due to reaching max number of label outputs given input count"; + finishedSegment_ = true; return false; } @@ -298,6 +301,7 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { if (requests_.empty()) { // All hypotheses are terminated -> no search step can be made. + finishedSegment_ = true; return false; } @@ -318,20 +322,27 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { extensions_[extensionIdx].timeframe = result->timeframes[extensionIdx]; } - if (logStepwiseStatistics_) { - clog() << Core::XmlOpen("search-step-stats"); - } - /* * Prune set of possible extensions by max beam size and possibly also by score. */ + if (logStepwiseStatistics_) { + clog() << Core::XmlOpen("search-step-stats"); + } + if (useScorePruning_) { scorePruningExtensions(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-extensions-after-score-pruning", extensions_.size()); + } } beamSizePruningExtensions(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-extensions-after-beam-pruning", extensions_.size()); + } + /* * Create new beam from surviving extensions. 
*/ @@ -365,17 +376,35 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { if (useScorePruning_) { scorePruning(); - numHypsAfterScorePruning_ += newBeam_.size(); + size_t numActive = std::accumulate( + newBeam_.begin(), + newBeam_.end(), + 0ul, + [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); + + numTerminatedHypsAfterScorePruning_ += newBeam_.size() - numActive; + numActiveHypsAfterScorePruning_ += numActive; if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-hyps-after-score-pruning", newBeam_.size()); + clog() << Core::XmlFull("num-terminated-hyps-after-score-pruning", newBeam_.size() - numActive); + clog() << Core::XmlFull("num-active-hyps-after-score-pruning", numActive); } } beamSizePruning(); - numHypsAfterBeamPruning_ += newBeam_.size(); + + size_t numActive = std::accumulate( + newBeam_.begin(), + newBeam_.end(), + 0ul, + [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); + + numTerminatedHypsAfterBeamPruning_ += newBeam_.size() - numActive; + numActiveHypsAfterBeamPruning_ += numActive; + if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-hyps-after-beam-pruning", newBeam_.size()); + clog() << Core::XmlFull("num-terminated-hyps-after-beam-pruning", newBeam_.size() - numActive); + clog() << Core::XmlFull("num-active-hyps-after-beam-pruning", numActive); } /* @@ -397,11 +426,11 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { std::stringstream ssTerminated; for (size_t hypIdx = 0ul; hypIdx < beam_.size(); ++hypIdx) { auto const& hyp = beam_[hypIdx]; - if (hyp.isActive) { - ssActive << "Active hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; + if (not hyp.isActive) { + ssTerminated << "Terminated hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; } else { - ssTerminated << "Terminated hypothesis " << hypIdx + 1ul << ": " << beam_[hypIdx].toString() << "\n"; + ssActive << "Active hypothesis " << hypIdx + 1ul << ": " << 
beam_[hypIdx].toString() << "\n"; } } ssActive << "\n"; @@ -410,19 +439,28 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { } if (logStepwiseStatistics_) { - size_t numActive = std::accumulate( - beam_.begin(), - beam_.end(), - 0ul, - [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); - auto const& bestHyp = getBestHypothesis(); - auto const& worstHyp = getWorstHypothesis(); - clog() << Core::XmlFull("active-hyps", numActive); clog() << Core::XmlFull("terminated-hyps", beam_.size() - numActive); - clog() << Core::XmlFull("best-hyp-score", bestHyp.score); - clog() << Core::XmlFull("worst-hyp-score", worstHyp.score); - clog() << Core::XmlFull("best-hyp-normed-score", bestHyp.scaledScore); - clog() << Core::XmlFull("worst-hyp-normed-score", worstHyp.scaledScore); + clog() << Core::XmlFull("active-hyps", numActive); + auto const* bestTerminatedHyp = getBestTerminatedHypothesis(); + auto const* worstTerminatedHyp = getWorstTerminatedHypothesis(); + auto const* bestActiveHyp = getBestActiveHypothesis(); + auto const* worstActiveHyp = getWorstActiveHypothesis(); + if (bestTerminatedHyp != nullptr) { + clog() << Core::XmlFull("best-terminated-hyp-score", bestTerminatedHyp->score); + clog() << Core::XmlFull("best-terminated-hyp-normalized-score", bestTerminatedHyp->scaledScore); + } + if (worstTerminatedHyp != nullptr) { + clog() << Core::XmlFull("worst-terminated-hyp-score", worstTerminatedHyp->score); + clog() << Core::XmlFull("worst-terminated-hyp-normalized-score", worstTerminatedHyp->scaledScore); + } + if (bestActiveHyp != nullptr) { + clog() << Core::XmlFull("best-active-hyp-score", bestActiveHyp->score); + clog() << Core::XmlFull("best-active-hyp-normalized-score", bestActiveHyp->scaledScore); + } + if (worstActiveHyp != nullptr) { + clog() << Core::XmlFull("worst-active-hyp-score", worstActiveHyp->score); + clog() << Core::XmlFull("worst-active-hyp-normalized-score", worstActiveHyp->scaledScore); + } clog() << 
Core::XmlClose("search-step-stats"); } @@ -430,54 +468,80 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { return true; } -LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getBestHypothesis() const { - LabelHypothesis const* bestActive = nullptr; - LabelHypothesis const* bestTerminated = nullptr; +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const* LexiconfreeLabelsyncBeamSearch::getBestTerminatedHypothesis() const { + LabelHypothesis const* best = nullptr; for (auto const& hyp : beam_) { - if (hyp.isActive) { - if (not bestActive or hyp < *bestActive) { - bestActive = &hyp; + if (not hyp.isActive) { + if (best == nullptr or hyp < *best) { + best = &hyp; } } - else { - if (not bestTerminated or hyp < *bestTerminated) { - bestTerminated = &hyp; + } + + return best; +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const* LexiconfreeLabelsyncBeamSearch::getWorstTerminatedHypothesis() const { + LabelHypothesis const* worst = nullptr; + + for (auto const& hyp : beam_) { + if (not hyp.isActive) { + if (worst == nullptr or hyp > *worst) { + worst = &hyp; } } } - if (bestTerminated) { - return *bestTerminated; - } - else { - return *bestActive; - } + return worst; } -LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getWorstHypothesis() const { - LabelHypothesis const* worstActive = nullptr; - LabelHypothesis const* worstTerminated = nullptr; +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const* LexiconfreeLabelsyncBeamSearch::getBestActiveHypothesis() const { + LabelHypothesis const* best = nullptr; for (auto const& hyp : beam_) { if (hyp.isActive) { - if (not worstActive or hyp > *worstActive) { - worstActive = &hyp; + if (best == nullptr or hyp < *best) { + best = &hyp; } } - else { - if (not worstTerminated or hyp > *worstTerminated) { - worstTerminated = &hyp; + } + + return best; +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const* 
LexiconfreeLabelsyncBeamSearch::getWorstActiveHypothesis() const { + LabelHypothesis const* worst = nullptr; + + for (auto const& hyp : beam_) { + if (hyp.isActive) { + if (worst == nullptr or hyp > *worst) { + worst = &hyp; } } } - if (worstTerminated) { - return *worstTerminated; + return worst; +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getBestHypothesis() const { + auto const* result = getBestTerminatedHypothesis(); + if (result != nullptr) { + return *result; } - else { - return *worstActive; + result = getBestActiveHypothesis(); + verify(result != nullptr); + return *result; +} + +LexiconfreeLabelsyncBeamSearch::LabelHypothesis const& LexiconfreeLabelsyncBeamSearch::getWorstHypothesis() const { + auto const* result = getWorstTerminatedHypothesis(); + if (result != nullptr) { + return *result; } + result = getWorstActiveHypothesis(); + verify(result != nullptr); + return *result; } void LexiconfreeLabelsyncBeamSearch::resetStatistics() { @@ -485,8 +549,10 @@ void LexiconfreeLabelsyncBeamSearch::resetStatistics() { featureProcessingTime_.reset(); scoringTime_.reset(); contextExtensionTime_.reset(); - numHypsAfterScorePruning_.clear(); - numHypsAfterBeamPruning_.clear(); + numTerminatedHypsAfterScorePruning_.clear(); + numTerminatedHypsAfterBeamPruning_.clear(); + numActiveHypsAfterScorePruning_.clear(); + numActiveHypsAfterBeamPruning_.clear(); } void LexiconfreeLabelsyncBeamSearch::logStatistics() const { @@ -496,8 +562,10 @@ void LexiconfreeLabelsyncBeamSearch::logStatistics() const { clog() << Core::XmlOpen("scoring-time") << scoringTime_.elapsedMilliseconds() << Core::XmlClose("scoring-time"); clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time"); clog() << Core::XmlClose("timing-statistics"); - numHypsAfterScorePruning_.write(clog()); - numHypsAfterBeamPruning_.write(clog()); + 
numTerminatedHypsAfterScorePruning_.write(clog()); + numTerminatedHypsAfterBeamPruning_.write(clog()); + numActiveHypsAfterScorePruning_.write(clog()); + numActiveHypsAfterBeamPruning_.write(clog()); } void LexiconfreeLabelsyncBeamSearch::beamSizePruningExtensions() { @@ -505,7 +573,7 @@ void LexiconfreeLabelsyncBeamSearch::beamSizePruningExtensions() { return; } - // Reorder the hypotheses by associated score value such that the first `beamSizeActive_` elements are the best + // Reorder the hypotheses by associated score value such that the first `maxBeamSize_` elements are the best std::nth_element(extensions_.begin(), extensions_.begin() + maxBeamSize_, extensions_.end()); extensions_.resize(maxBeamSize_); // Get rid of excessive elements } @@ -515,7 +583,7 @@ void LexiconfreeLabelsyncBeamSearch::beamSizePruning() { return; } - // Reorder the hypotheses by associated score value such that the first `beamSizeTerminated_` elements are the best + // Reorder the hypotheses by associated score value such that the first `maxBeamSize_` elements are the best std::nth_element(newBeam_.begin(), newBeam_.begin() + maxBeamSize_, newBeam_.end()); newBeam_.resize(maxBeamSize_); // Get rid of excessive elements } @@ -560,7 +628,8 @@ void LexiconfreeLabelsyncBeamSearch::scorePruning() { void LexiconfreeLabelsyncBeamSearch::recombination() { recombinedHypotheses_.clear(); - // Map each unique ScoringContext in newHypotheses to its hypothesis + + // Map each unique ScoringContext in `newBeam_` to its hypothesis std::unordered_map seenScoringContexts; for (auto const& hyp : newBeam_) { // Use try_emplace to check if the scoring context already exists and create a new entry if not at the same time diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh index 99a245ce..91b20b53 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh +++ 
b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh @@ -1,4 +1,4 @@ -/** Copyright 2025 RWTH Aachen University. All rights reserved. +/** Copyright 2026 RWTH Aachen University. All rights reserved. * * Licensed under the RWTH ASR License (the "License"); * you may not use this file except in compliance with the License. @@ -140,13 +140,21 @@ private: Core::StopWatch scoringTime_; Core::StopWatch contextExtensionTime_; - Core::Statistics numHypsAfterScorePruning_; - Core::Statistics numHypsAfterBeamPruning_; + Core::Statistics numTerminatedHypsAfterScorePruning_; + Core::Statistics numTerminatedHypsAfterBeamPruning_; + Core::Statistics numActiveHypsAfterScorePruning_; + Core::Statistics numActiveHypsAfterBeamPruning_; size_t currentSearchStep_; size_t totalTimesteps_; bool finishedSegment_; + LabelHypothesis const* getBestTerminatedHypothesis() const; + LabelHypothesis const* getWorstTerminatedHypothesis() const; + + LabelHypothesis const* getBestActiveHypothesis() const; + LabelHypothesis const* getWorstActiveHypothesis() const; + LabelHypothesis const& getBestHypothesis() const; LabelHypothesis const& getWorstHypothesis() const; diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index b08840f3..d2568a8a 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -68,6 +68,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( trace->score.acoustic = extension.score; trace->time = extension.timeframe + 1; break; + default: + defect(); // Unexpected transition type which can not be produced by `inferTransitionType` } } diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 49d9bea8..2d87352b 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -15,6 +15,7 @@ #include #include #include 
+#include "LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh" #include "LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh" #include "TreeBuilder.hh" #ifdef MODULE_SEARCH_WFST @@ -35,6 +36,7 @@ Module_::Module_() { const Core::Choice Module_::searchTypeV2Choice( "lexiconfree-timesync-beam-search", SearchTypeV2::LexiconfreeTimesyncBeamSearchType, + "lexiconfree-labelsync-beam-search", SearchTypeV2::LexiconfreeLabelsyncBeamSearchType, Core::Choice::endMark()); const Core::ParameterChoice Module_::searchTypeV2Param( @@ -110,6 +112,9 @@ SearchAlgorithmV2* Module_::createSearchAlgorithmV2(const Core::Configuration& c case LexiconfreeTimesyncBeamSearchType: searchAlgorithm = new Search::LexiconfreeTimesyncBeamSearch(config); break; + case LexiconfreeLabelsyncBeamSearchType: + searchAlgorithm = new Search::LexiconfreeLabelsyncBeamSearch(config); + break; default: Core::Application::us()->criticalError("Unknown search algorithm type: %d", searchTypeV2Param(config)); break; diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 4efed485..16d89913 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -17,6 +17,7 @@ #include #include +#include "LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh" #include "SearchV2.hh" #include "TreeBuilder.hh" @@ -41,7 +42,8 @@ enum SearchType { }; enum SearchTypeV2 { - LexiconfreeTimesyncBeamSearchType + LexiconfreeTimesyncBeamSearchType, + LexiconfreeLabelsyncBeamSearchType }; class Module_ { From fad9928e99ff9016225eef370e6d4ce25b221004 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 28 May 2025 17:58:32 +0200 Subject: [PATCH 47/52] Add missing include --- .../LexiconfreeLabelsyncBeamSearch.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index 813b6f24..e613d2a2 100644 --- 
a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -16,6 +16,7 @@ #include "LexiconfreeLabelsyncBeamSearch.hh" #include +#include #include #include From b756a822d39298d93f82bfd4495310c3572b4b07 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 30 May 2025 15:54:40 +0200 Subject: [PATCH 48/52] Fix typo in year --- .../LexiconfreeLabelsyncBeamSearch.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh index 91b20b53..7b6c8d92 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh @@ -1,4 +1,4 @@ -/** Copyright 2026 RWTH Aachen University. All rights reserved. +/** Copyright 2025 RWTH Aachen University. All rights reserved. * * Licensed under the RWTH ASR License (the "License"); * you may not use this file except in compliance with the License. 
From 311eb8f7d4a34ddccb9f5191be80c3a52f95ce2d Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 28 Jul 2025 16:22:54 +0200 Subject: [PATCH 49/52] Fix pruning order and logging --- .../LexiconfreeLabelsyncBeamSearch.cc | 118 +++++++++--------- .../LexiconfreeLabelsyncBeamSearch.hh | 69 +++++----- .../LexiconfreeTimesyncBeamSearch.cc | 5 - 3 files changed, 95 insertions(+), 97 deletions(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index e613d2a2..24370ce5 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -22,7 +22,6 @@ #include #include #include -#include #include namespace Search { @@ -83,7 +82,7 @@ std::string LexiconfreeLabelsyncBeamSearch::LabelHypothesis::toString() const { const Core::ParameterInt LexiconfreeLabelsyncBeamSearch::paramMaxBeamSize( "max-beam-size", - "Maximum number of hypotheses in the search beam.", + "Maximum number of elements in the search beam.", 1, 1); const Core::ParameterFloat LexiconfreeLabelsyncBeamSearch::paramScoreThreshold( @@ -113,6 +112,11 @@ const Core::ParameterBool LexiconfreeLabelsyncBeamSearch::paramLogStepwiseStatis "Log statistics about the beam at every search step.", false); +const Core::ParameterBool LexiconfreeLabelsyncBeamSearch::paramCacheCleanupInterval( + "cache-cleanup-interval", + "Interval of search steps after which buffered inputs that are not needed anymore get cleaned up.", + 10); + LexiconfreeLabelsyncBeamSearch::LexiconfreeLabelsyncBeamSearch(Core::Configuration const& config) : Core::Component(config), SearchAlgorithmV2(config), @@ -122,6 +126,7 @@ LexiconfreeLabelsyncBeamSearch::LexiconfreeLabelsyncBeamSearch(Core::Configurati maxLabelsPerTimestep_(paramMaxLabelsPerTimestep(config)), sentenceEndLabelIndex_(paramSentenceEndLabelIndex(config)), 
logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + cacheCleanupInterval_(paramCacheCleanupInterval(config)), debugChannel_(config, "debug"), labelScorer_(), beam_(), @@ -134,19 +139,19 @@ LexiconfreeLabelsyncBeamSearch::LexiconfreeLabelsyncBeamSearch(Core::Configurati scoringTime_(), contextExtensionTime_(), numTerminatedHypsAfterScorePruning_("num-termianted-hyps-after-score-pruning"), + numTerminatedHypsAfterRecombination_("num-terminated-hyps-after-recombination"), numTerminatedHypsAfterBeamPruning_("num-terminated-hyps-after-beam-pruning"), numActiveHypsAfterScorePruning_("num-active-hyps-after-score-pruning"), + numActiveHypsAfterRecombination_("num-active-hyps-after-recombination"), numActiveHypsAfterBeamPruning_("num-active-hyps-after-beam-pruning"), currentSearchStep_(0ul), totalTimesteps_(0ul), finishedSegment_(false) { - beam_.reserve(maxBeamSize_); - newBeam_.reserve(maxBeamSize_ * 2); // terminated + active - recombinedHypotheses_.reserve(maxBeamSize_ * 2); // terminated + active - useScorePruning_ = scoreThreshold_ != Core::Type::max; - log() << "Use sentence-end label with index " << sentenceEndLabelIndex_; + if (sentenceEndLabelIndex_ != Core::Type::max) { + log() << "Use sentence-end label with index " << sentenceEndLabelIndex_; + } } Speech::ModelCombination::Mode LexiconfreeLabelsyncBeamSearch::requiredModelCombination() const { @@ -157,15 +162,12 @@ bool LexiconfreeLabelsyncBeamSearch::setModelCombination(Speech::ModelCombinatio lexicon_ = modelCombination.lexicon(); labelScorer_ = modelCombination.labelScorer(); - extensions_.reserve(maxBeamSize_ * lexicon_->nLemmas()); - requests_.reserve(extensions_.size()); - auto sentenceEndLemma = lexicon_->specialLemma("sentence-end"); if (!sentenceEndLemma) { sentenceEndLemma = lexicon_->specialLemma("sentence-boundary"); } if (sentenceEndLemma) { - if (sentenceEndLabelIndex_ == Core::Type::max) { + if (sentenceEndLabelIndex_ == Core::Type::max) { sentenceEndLabelIndex_ = 
sentenceEndLemma->id(); log() << "Use sentence-end index " << sentenceEndLabelIndex_ << " inferred from lexicon"; } @@ -323,14 +325,13 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { extensions_[extensionIdx].timeframe = result->timeframes[extensionIdx]; } - /* - * Prune set of possible extensions by max beam size and possibly also by score. - */ - if (logStepwiseStatistics_) { clog() << Core::XmlOpen("search-step-stats"); } + /* + * Maybe prune set of possible extensions by score. + */ if (useScorePruning_) { scorePruningExtensions(); if (logStepwiseStatistics_) { @@ -338,12 +339,6 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { } } - beamSizePruningExtensions(); - - if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-extensions-after-beam-pruning", extensions_.size()); - } - /* * Create new beam from surviving extensions. */ @@ -366,61 +361,69 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { } /* - * For all hypotheses with the same scoring context keep only the best since they will - * all develop in the same way. 
- */ - recombination(); - - /* - * Jointly prune terminated and active hypotheses + * Jointly prune terminated and active hypotheses by score */ if (useScorePruning_) { scorePruning(); - size_t numActive = std::accumulate( - newBeam_.begin(), - newBeam_.end(), - 0ul, - [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); + size_t numActive = numActiveHyps(); + size_t numTerminated = newBeam_.size() - numActive; - numTerminatedHypsAfterScorePruning_ += newBeam_.size() - numActive; + numTerminatedHypsAfterScorePruning_ += numTerminated; numActiveHypsAfterScorePruning_ += numActive; if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-terminated-hyps-after-score-pruning", newBeam_.size() - numActive); + clog() << Core::XmlFull("num-terminated-hyps-after-score-pruning", numTerminated); clog() << Core::XmlFull("num-active-hyps-after-score-pruning", numActive); } } + /* + * For all hypotheses with the same scoring context keep only the best since they will + * all develop in the same way. 
+ */ + recombination(); + + size_t numActive = numActiveHyps(); + size_t numTerminated = newBeam_.size() - numActive; + + numTerminatedHypsAfterRecombination_ += numTerminated; + numActiveHypsAfterRecombination_ += numActive; + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-terminated-hyps-after-recombination", numTerminated); + clog() << Core::XmlFull("num-active-hyps-after-recombination", numActive); + } + beamSizePruning(); - size_t numActive = std::accumulate( - newBeam_.begin(), - newBeam_.end(), - 0ul, - [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); + numActive = numActiveHyps(); + numTerminated = newBeam_.size() - numActive; - numTerminatedHypsAfterBeamPruning_ += newBeam_.size() - numActive; + numTerminatedHypsAfterBeamPruning_ += numTerminated; numActiveHypsAfterBeamPruning_ += numActive; if (logStepwiseStatistics_) { - clog() << Core::XmlFull("num-terminated-hyps-after-beam-pruning", newBeam_.size() - numActive); + clog() << Core::XmlFull("num-terminated-hyps-after-beam-pruning", numTerminated); clog() << Core::XmlFull("num-active-hyps-after-beam-pruning", numActive); } /* * Clean up label scorer caches. */ - Core::CollapsedVector activeContexts; - for (auto const& hyp : newBeam_) { - activeContexts.push_back(hyp.scoringContext); + if (++currentSearchStep_ % cacheCleanupInterval_ == 0) { + Core::CollapsedVector activeContexts; + for (auto const& hyp : newBeam_) { + activeContexts.push_back(hyp.scoringContext); + } + labelScorer_->cleanupCaches(activeContexts); } - labelScorer_->cleanupCaches(activeContexts); + + beam_.swap(newBeam_); /* * Log statistics about the new beam after this step. 
*/ - beam_.swap(newBeam_); if (debugChannel_.isOpen()) { std::stringstream ssActive; @@ -440,8 +443,6 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { } if (logStepwiseStatistics_) { - clog() << Core::XmlFull("terminated-hyps", beam_.size() - numActive); - clog() << Core::XmlFull("active-hyps", numActive); auto const* bestTerminatedHyp = getBestTerminatedHypothesis(); auto const* worstTerminatedHyp = getWorstActiveHypothesis(); auto const* bestActiveHyp = getBestActiveHypothesis(); @@ -465,7 +466,6 @@ bool LexiconfreeLabelsyncBeamSearch::decodeStep() { clog() << Core::XmlClose("search-step-stats"); } - ++currentSearchStep_; return true; } @@ -569,16 +569,6 @@ void LexiconfreeLabelsyncBeamSearch::logStatistics() const { numActiveHypsAfterBeamPruning_.write(clog()); } -void LexiconfreeLabelsyncBeamSearch::beamSizePruningExtensions() { - if (extensions_.size() <= maxBeamSize_) { - return; - } - - // Reorder the hypotheses by associated score value such that the first `maxBeamSize_` elements are the best - std::nth_element(extensions_.begin(), extensions_.begin() + maxBeamSize_, extensions_.end()); - extensions_.resize(maxBeamSize_); // Get rid of excessive elements -} - void LexiconfreeLabelsyncBeamSearch::beamSizePruning() { if (newBeam_.size() <= maxBeamSize_) { return; @@ -661,4 +651,12 @@ void LexiconfreeLabelsyncBeamSearch::recombination() { newBeam_.swap(recombinedHypotheses_); } +size_t LexiconfreeLabelsyncBeamSearch::numActiveHyps() const { + return std::accumulate( + newBeam_.begin(), + newBeam_.end(), + 0ul, + [](size_t acc, auto const& hyp) { return acc + static_cast(hyp.isActive); }); +} + } // namespace Search diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh index 7b6c8d92..863c220c 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh +++ 
b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh @@ -39,6 +39,31 @@ namespace Search { * in the lexicon corresponding to the associated output index of the label scorer. */ class LexiconfreeLabelsyncBeamSearch : public SearchAlgorithmV2 { +public: + static const Core::ParameterInt paramMaxBeamSize; + static const Core::ParameterFloat paramScoreThreshold; + + static const Core::ParameterInt paramSentenceEndLabelIndex; + static const Core::ParameterBool paramCacheCleanupInterval; + static const Core::ParameterFloat paramLengthNormScale; + static const Core::ParameterFloat paramMaxLabelsPerTimestep; + static const Core::ParameterBool paramLogStepwiseStatistics; + + LexiconfreeLabelsyncBeamSearch(Core::Configuration const&); + + // Inherited methods from `SearchAlgorithmV2` + + Speech::ModelCombination::Mode requiredModelCombination() const override; + bool setModelCombination(Speech::ModelCombination const& modelCombination) override; + void reset() override; + void enterSegment(Bliss::SpeechSegment const* = nullptr) override; + void finishSegment() override; + void putFeature(Nn::DataView const& feature) override; + void putFeatures(Nn::DataView const& features, size_t nTimesteps) override; + Core::Ref getCurrentBestTraceback() const override; + Core::Ref getCurrentBestWordLattice() const override; + bool decodeStep() override; + protected: /* * Possible extension for some label hypothesis in the beam @@ -85,30 +110,6 @@ protected: std::string toString() const; }; -public: - static const Core::ParameterInt paramMaxBeamSize; - static const Core::ParameterFloat paramScoreThreshold; - - static const Core::ParameterInt paramSentenceEndLabelIndex; - static const Core::ParameterFloat paramLengthNormScale; - static const Core::ParameterFloat paramMaxLabelsPerTimestep; - static const Core::ParameterBool paramLogStepwiseStatistics; - - LexiconfreeLabelsyncBeamSearch(Core::Configuration const&); - - // Inherited methods from `SearchAlgorithmV2` - - 
Speech::ModelCombination::Mode requiredModelCombination() const override; - bool setModelCombination(Speech::ModelCombination const& modelCombination) override; - void reset() override; - void enterSegment(Bliss::SpeechSegment const* = nullptr) override; - void finishSegment() override; - void putFeature(Nn::DataView const& feature) override; - void putFeatures(Nn::DataView const& features, size_t nTimesteps) override; - Core::Ref getCurrentBestTraceback() const override; - Core::Ref getCurrentBestWordLattice() const override; - bool decodeStep() override; - private: size_t maxBeamSize_; @@ -123,6 +124,8 @@ private: bool logStepwiseStatistics_; + size_t cacheCleanupInterval_; + Core::Channel debugChannel_; Core::Ref labelScorer_; @@ -141,8 +144,10 @@ private: Core::StopWatch contextExtensionTime_; Core::Statistics numTerminatedHypsAfterScorePruning_; + Core::Statistics numTerminatedHypsAfterRecombination_; Core::Statistics numTerminatedHypsAfterBeamPruning_; Core::Statistics numActiveHypsAfterScorePruning_; + Core::Statistics numActiveHypsAfterRecombination_; Core::Statistics numActiveHypsAfterBeamPruning_; size_t currentSearchStep_; @@ -162,22 +167,17 @@ private: void logStatistics() const; /* - * Helper function for pruning of extensions to maxBeamSize_ - */ - void beamSizePruningExtensions(); - - /* - * Helper function for pruning of hyps to maxBeamSize_ + * Helper function for pruning of hyps to `maxBeamSize_` */ void beamSizePruning(); /* - * Helper function for pruning of extensions to scoreThreshold__ + * Helper function for pruning of extensions to `scoreThreshold_` */ void scorePruningExtensions(); /* - * Helper function for pruning of hyps to scoreThreshold_ + * Helper function for pruning of hyps to `scoreThreshold_` */ void scorePruning(); @@ -185,6 +185,11 @@ private: * Helper function for recombination of hypotheses with the same scoring context */ void recombination(); + + /* + * Count hyps with `isActive` flag in `newBeam_` + */ + size_t 
numActiveHyps() const; }; } // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 87be609b..11766979 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -365,11 +365,6 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { */ beam_.swap(newBeam_); - /* - * Log statistics about the new beam after this step. - */ - beam_.swap(newBeam_); - if (debugChannel_.isOpen()) { std::stringstream ss; for (size_t hypIdx = 0ul; hypIdx < beam_.size(); ++hypIdx) { From a71e729d675635d41f828b9413a2b7630be67569 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 28 Jul 2025 16:27:45 +0200 Subject: [PATCH 50/52] Remove unnecessary include --- src/Search/Module.hh | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 7d225014..39dcc819 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -17,7 +17,6 @@ #include #include -#include "LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.hh" #include "SearchV2.hh" #include "TreeBuilder.hh" From 8fc3d7a6255e6e5b9a1995ca9b2c872d5e614250 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 28 Jul 2025 16:33:30 +0200 Subject: [PATCH 51/52] Some synchronization with updates from master --- src/Core/CollapsedVector.hh | 26 ----------------------- src/Nn/LabelScorer/BufferedLabelScorer.hh | 1 - 2 files changed, 27 deletions(-) diff --git a/src/Core/CollapsedVector.hh b/src/Core/CollapsedVector.hh index 73fcea48..e6b4492c 100644 --- a/src/Core/CollapsedVector.hh +++ b/src/Core/CollapsedVector.hh @@ -52,12 +52,6 @@ public: inline size_t internalSize() const; inline std::vector const& internalData() const; - inline typename std::vector::iterator begin(); - inline typename std::vector::iterator end(); - - inline typename 
std::vector::const_iterator begin() const; - inline typename std::vector::const_iterator end() const; - private: std::vector data_; size_t logicalSize_; @@ -141,26 +135,6 @@ inline std::vector const& CollapsedVector::internalData() const { return data_; } -template -inline typename std::vector::iterator CollapsedVector::begin() { - return data_.begin(); -} - -template -inline typename std::vector::iterator CollapsedVector::end() { - return data_.end(); -} - -template -inline typename std::vector::const_iterator CollapsedVector::begin() const { - return data_.begin(); -} - -template -inline typename std::vector::const_iterator CollapsedVector::end() const { - return data_.end(); -} - } // namespace Core #endif // COLLAPSED_VECTOR_HH diff --git a/src/Nn/LabelScorer/BufferedLabelScorer.hh b/src/Nn/LabelScorer/BufferedLabelScorer.hh index e6ae3454..6ce005db 100644 --- a/src/Nn/LabelScorer/BufferedLabelScorer.hh +++ b/src/Nn/LabelScorer/BufferedLabelScorer.hh @@ -19,7 +19,6 @@ #include #include "DataView.hh" #include "LabelScorer.hh" -#include "Speech/Types.hh" namespace Nn { From 3647314e93d2fb350274b815a5cf5a912ca51621 Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:04:47 +0200 Subject: [PATCH 52/52] Use `Nn::invalidLabelIndex` in `LabelHypothesi` constructor Co-authored-by: larissakl <56513675+larissakl@users.noreply.github.com> --- .../LexiconfreeLabelsyncBeamSearch.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc index 24370ce5..48fd5344 100644 --- a/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc +++ b/src/Search/LexiconfreeLabelsyncBeamSearch/LexiconfreeLabelsyncBeamSearch.cc @@ -34,7 +34,7 @@ namespace Search { LexiconfreeLabelsyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), - 
currentToken(Core::Type::max), + currentToken(Nn::invalidLabelIndex), length(0), score(0.0), scaledScore(0.0),