diff --git a/src/Nn/LabelScorer/Makefile b/src/Nn/LabelScorer/Makefile index 22d177572..fb419b22b 100644 --- a/src/Nn/LabelScorer/Makefile +++ b/src/Nn/LabelScorer/Makefile @@ -22,7 +22,8 @@ LIBSPRINTLABELSCORER_O = \ $(OBJDIR)/NoContextOnnxLabelScorer.o \ $(OBJDIR)/NoOpLabelScorer.o \ $(OBJDIR)/ScoringContext.o \ - $(OBJDIR)/StatefulOnnxLabelScorer.o + $(OBJDIR)/StatefulOnnxLabelScorer.o \ + $(OBJDIR)/StatefulTransducerOnnxLabelScorer.o # ----------------------------------------------------------------------------- diff --git a/src/Nn/LabelScorer/ScoringContext.cc b/src/Nn/LabelScorer/ScoringContext.cc index aa3fa8eef..4bbdc3dcd 100644 --- a/src/Nn/LabelScorer/ScoringContext.cc +++ b/src/Nn/LabelScorer/ScoringContext.cc @@ -136,4 +136,36 @@ bool OnnxHiddenStateScoringContext::isEqual(ScoringContextRef const& other) cons return true; } +/* + * ===================================== + * = StepOnnxHiddenStateScoringContext = + * ===================================== + */ +size_t StepOnnxHiddenStateScoringContext::hash() const { + return Core::combineHashes(currentStep, Core::MurmurHash3_x64_64(reinterpret_cast(labelSeq.data()), labelSeq.size() * sizeof(LabelIndex), 0x78b174eb)); +} + +bool StepOnnxHiddenStateScoringContext::isEqual(ScoringContextRef const& other) const { + auto* otherPtr = dynamic_cast(other.get()); + if (otherPtr == nullptr) { + return false; + } + + if (currentStep != otherPtr->currentStep) { + return false; + } + + if (labelSeq.size() != otherPtr->labelSeq.size()) { + return false; + } + + for (auto it_l = labelSeq.begin(), it_r = otherPtr->labelSeq.begin(); it_l != labelSeq.end(); ++it_l, ++it_r) { + if (*it_l != *it_r) { + return false; + } + } + + return true; +} + } // namespace Nn diff --git a/src/Nn/LabelScorer/ScoringContext.hh b/src/Nn/LabelScorer/ScoringContext.hh index af2adf10f..8404ae8bb 100644 --- a/src/Nn/LabelScorer/ScoringContext.hh +++ b/src/Nn/LabelScorer/ScoringContext.hh @@ -148,6 +148,29 @@ struct OnnxHiddenStateScoringContext 
: public ScoringContext { typedef Core::Ref OnnxHiddenStateScoringContextRef; +/* + * Scoring context consisting of a hidden state, the label sequence it was created from and a step. + * Assumes that two hidden states are equal if and only if they were created + * from the same label history. + */ +struct StepOnnxHiddenStateScoringContext : public ScoringContext { + Speech::TimeframeIndex currentStep; + std::vector labelSeq; // Used for hashing and equality + mutable OnnxHiddenStateRef hiddenState; // Replaced lazily by the label scorer when the context is finalized + mutable bool requiresFinalize; // True while hiddenState does not yet include the last label of labelSeq + + StepOnnxHiddenStateScoringContext() + : currentStep(0u), labelSeq(), hiddenState(), requiresFinalize(false) {} + + StepOnnxHiddenStateScoringContext(Speech::TimeframeIndex step, std::vector const& labelSeq, OnnxHiddenStateRef state) + : currentStep(step), labelSeq(labelSeq), hiddenState(state), requiresFinalize(false) {} + + bool isEqual(ScoringContextRef const& other) const; + size_t hash() const; +}; + +typedef Core::Ref StepOnnxHiddenStateScoringContextRef; + } // namespace Nn #endif // SCORING_CONTEXT_HH diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh index 0f343ac51..ca518898d 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh @@ -56,6 +56,10 @@ namespace Nn { * * A common use case for this Label Scorer would be an AED model with cross-attention over the encoder output. * Since the encoder state inputs are optional, it can also be used for stateful language models without acoustic input. + * + * Note: This LabelScorer is similar to the `StatefulTransducerOnnxLabelScorer`. The difference is that this one assumes that the + * input features are processed into the hidden states and are not fed directly into the scorer. For this, the state initializer + * and updater here also take input features in addition to tokens.
*/ class StatefulOnnxLabelScorer : public BufferedLabelScorer { using Precursor = BufferedLabelScorer; diff --git a/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.cc new file mode 100644 index 000000000..035abc22d --- /dev/null +++ b/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.cc @@ -0,0 +1,425 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "StatefulTransducerOnnxLabelScorer.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "LabelScorer.hh" +#include "ScoringContext.hh" + +namespace Nn { + +/* + * ======================================= + * == StatefulTransducerOnnxLabelScorer == + * ======================================= + */ + +const Core::ParameterBool StatefulTransducerOnnxLabelScorer::paramBlankUpdatesHistory( + "blank-updates-history", + "Whether previously emitted blank labels should be used to update the history.", + false); + +const Core::ParameterBool StatefulTransducerOnnxLabelScorer::paramLoopUpdatesHistory( + "loop-updates-history", + "Whether in the case of loop transitions every repeated emission should be used to update the history.", + false); + +const Core::ParameterBool StatefulTransducerOnnxLabelScorer::paramVerticalLabelTransition( + "vertical-label-transition", + "Whether (non-blank) label transitions should be vertical, i.e. 
not increase the time step.", + false); + +const Core::ParameterInt StatefulTransducerOnnxLabelScorer::paramMaxBatchSize( + "max-batch-size", + "Max number of hidden-states that can be fed into the scorer ONNX model at once.", + Core::Type::max); + +const Core::ParameterInt StatefulTransducerOnnxLabelScorer::paramMaxCachedScores( + "max-cached-score-vectors", + "Maximum size of cache that maps histories to scores. This prevents memory overflow in case of very long audio segments.", + 1000); + +// Besides the input feature, the scorer takes hidden states as input which are not part of the IO spec +const std::vector scorerModelIoSpec = { + Onnx::IOSpecification{ + "input-feature", + Onnx::IODirection::INPUT, + false, + {Onnx::ValueType::TENSOR}, + {Onnx::ValueDataType::FLOAT}, + {{-1, -2}, {1, -2}}}, // [1, E] + Onnx::IOSpecification{ + "scores", + Onnx::IODirection::OUTPUT, + false, + {Onnx::ValueType::TENSOR}, + {Onnx::ValueDataType::FLOAT}, + {{-1, -2}}}}; // [B, V] + +const std::vector stateUpdaterModelIoSpec = { + Onnx::IOSpecification{ + "token", + Onnx::IODirection::INPUT, + false, + {Onnx::ValueType::TENSOR}, + {Onnx::ValueDataType::INT32}, + {{1}, {-1}}}}; // [1] or [B] + +StatefulTransducerOnnxLabelScorer::StatefulTransducerOnnxLabelScorer(Core::Configuration const& config) + : Core::Component(config), + Precursor(config), + blankUpdatesHistory_(paramBlankUpdatesHistory(config)), + loopUpdatesHistory_(paramLoopUpdatesHistory(config)), + verticalLabelTransition_(paramVerticalLabelTransition(config)), + maxBatchSize_(paramMaxBatchSize(config)), + scorerOnnxModel_(select("scorer-model"), scorerModelIoSpec), + stateInitializerOnnxModel_(select("state-initializer-model"), {}), + stateUpdaterOnnxModel_(select("state-updater-model"), stateUpdaterModelIoSpec), + initialScoringContext_(), + initializerOutputToStateNameMap_(), + updaterInputToStateNameMap_(), + updaterOutputToStateNameMap_(), + scorerInputToStateNameMap_(), + 
scorerInputFeatureName_(scorerOnnxModel_.mapping.getOnnxName("input-feature")), + scorerScoresName_(scorerOnnxModel_.mapping.getOnnxName("scores")), + updaterTokenName_(stateUpdaterOnnxModel_.mapping.getOnnxName("token")), + scoreCache_(paramMaxCachedScores(config)) { + auto initializerMetadataKeys = stateInitializerOnnxModel_.session.getCustomMetadataKeys(); + auto updaterMetadataKeys = stateUpdaterOnnxModel_.session.getCustomMetadataKeys(); + auto scorerMetadataKeys = scorerOnnxModel_.session.getCustomMetadataKeys(); + + // Map state initializer outputs to states + std::unordered_set initializerStateNames; + for (auto const& key : initializerMetadataKeys) { + if (stateInitializerOnnxModel_.session.hasOutput(key)) { + auto stateName = stateInitializerOnnxModel_.session.getCustomMetadata(key); + initializerOutputToStateNameMap_.emplace(key, stateName); + initializerStateNames.insert(stateName); + } + } + if (initializerStateNames.empty()) { + error() << "State initializer does not define any hidden states."; + } + + // Map state updater inputs and outputs to states + std::unordered_set updaterStateNames; + for (auto const& key : updaterMetadataKeys) { + if (stateUpdaterOnnxModel_.session.hasInput(key)) { + auto stateName = stateUpdaterOnnxModel_.session.getCustomMetadata(key); + if (initializerStateNames.find(stateName) == initializerStateNames.end()) { + error() << "State updater input " << key << " associated with state " << stateName << " is not present in state initializer"; + } + updaterInputToStateNameMap_.emplace(key, stateName); + } + if (stateUpdaterOnnxModel_.session.hasOutput(key)) { + auto stateName = stateUpdaterOnnxModel_.session.getCustomMetadata(key); + if (initializerStateNames.find(stateName) == initializerStateNames.end()) { + error() << "State updater output " << key << " associated with state " << stateName << " is not present in state initializer"; + } + updaterOutputToStateNameMap_.emplace(key, stateName); + 
updaterStateNames.insert(stateName); + } + } + if (updaterOutputToStateNameMap_.empty()) { + error() << "State updater does not produce any updated hidden states"; + } + + // In the loop we checked that the updater outputs are a subset of the initializer outputs. + // If they have the same size, they are equal. Otherwise, some initializer outputs + // are not updater outputs. NOTE(review): updatedHiddenState() builds the new hidden state from the updater outputs only, so such states are missing after the first update - confirm that a warning is sufficient here. + if (initializerStateNames.size() != updaterStateNames.size()) { + warning() << "State initializer has states that are not updated by the state updater"; + } + + // Map scorer inputs to states + for (auto const& key : scorerMetadataKeys) { + if (scorerOnnxModel_.session.hasInput(key)) { + auto stateName = scorerOnnxModel_.session.getCustomMetadata(key); + if (initializerStateNames.find(stateName) == initializerStateNames.end()) { + error() << "Scorer input " << key << " associated with state " << stateName << " is not present in state initializer"; + } + scorerInputToStateNameMap_.emplace(key, stateName); + } + } + if (scorerInputToStateNameMap_.empty()) { + error() << "Scorer does not take any input hidden-states"; + } +} + +void StatefulTransducerOnnxLabelScorer::reset() { + Precursor::reset(); + scoreCache_.clear(); +} + +Core::Ref StatefulTransducerOnnxLabelScorer::getInitialScoringContext() { + if (not initialScoringContext_) { + std::vector sessionOutputNames; + std::vector stateNames; + for (auto const& [outputName, stateName] : initializerOutputToStateNameMap_) { + sessionOutputNames.push_back(outputName); + stateNames.push_back(stateName); + } + + std::vector sessionOutputs; + stateInitializerOnnxModel_.session.run({}, sessionOutputNames, sessionOutputs); + + auto initialHiddenState = Core::ref(new OnnxHiddenState(std::move(stateNames), std::move(sessionOutputs))); + initialScoringContext_ = Core::ref(new StepOnnxHiddenStateScoringContext(0ul, std::vector(), initialHiddenState)); + } + + return initialScoringContext_; +} + +Core::Ref 
StatefulTransducerOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { + StepOnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); + + bool pushToken = false; + size_t timeIncrement = 0ul; + switch (request.transitionType) { + case LabelScorer::TransitionType::BLANK_LOOP: + pushToken = blankUpdatesHistory_ and loopUpdatesHistory_; + timeIncrement = 1ul; + break; + case LabelScorer::TransitionType::LABEL_TO_BLANK: + case LabelScorer::TransitionType::INITIAL_BLANK: + pushToken = blankUpdatesHistory_; + timeIncrement = 1ul; + break; + case LabelScorer::TransitionType::LABEL_LOOP: + pushToken = loopUpdatesHistory_; + timeIncrement = not verticalLabelTransition_; + break; + case LabelScorer::TransitionType::BLANK_TO_LABEL: + case LabelScorer::TransitionType::LABEL_TO_LABEL: + case LabelScorer::TransitionType::INITIAL_LABEL: + pushToken = true; + timeIncrement = not verticalLabelTransition_; + break; + default: + error() << "Unknown transition type " << request.transitionType; + } + + // If scoringContext is not going to be modified, return the original one + if (not pushToken and timeIncrement == 0ul) { + return request.context; + } + + std::vector newLabelSeq(scoringContext->labelSeq); // NOTE(review): scoringContext stems from an unchecked dynamic_cast - confirm request.context always has this type + bool requiresFinalize = false; // NOTE(review): presumably request.context has already been finalized; a label still pending in it would not be forwarded through the state updater - confirm + + if (pushToken) { + newLabelSeq.push_back(request.nextToken); + requiresFinalize = true; + } + + // Re-use previous hidden-state but mark that finalization (i.e. 
hidden-state update) is required + auto newScoringContext = Core::ref(new StepOnnxHiddenStateScoringContext(scoringContext->currentStep + timeIncrement, std::move(newLabelSeq), scoringContext->hiddenState)); + newScoringContext->requiresFinalize = requiresFinalize; + + return newScoringContext; +} + +std::optional StatefulTransducerOnnxLabelScorer::computeScoresWithTimes(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes(); + } + + ScoresWithTimes result; + result.scores.reserve(requests.size()); + + /* + * Collect all requests that are based on the same timestep (-> same input feature) and + * group them together + */ + std::unordered_map> requestsWithTimestep; // Maps timestep to list of all indices of requests with that timestep + + for (size_t b = 0ul; b < requests.size(); ++b) { + StepOnnxHiddenStateScoringContextRef context(dynamic_cast(requests[b].context.get())); + finalizeScoringContext(context); + auto step = context->currentStep; + + if (not getInput(step)) { + // Early exit if at least one of the histories is not scorable yet + return {}; + } + result.timeframes.push_back(step); + + // Create new vector if step value isn't present in map yet + auto [it, inserted] = requestsWithTimestep.emplace(step, std::vector()); + it->second.push_back(b); + } + + /* + * Iterate over distinct timesteps + */ + for (auto const& [timestep, requestIndices] : requestsWithTimestep) { + /* + * Identify unique histories that still need session runs + */ + std::unordered_set uniqueUncachedHistories; + + for (auto requestIndex : requestIndices) { + StepOnnxHiddenStateScoringContextRef scoringContextRef(dynamic_cast(requests[requestIndex].context.get())); + if (not scoreCache_.contains(scoringContextRef)) { + // Group by unique scoringContext + uniqueUncachedHistories.emplace(scoringContextRef); + } + } + + if (uniqueUncachedHistories.empty()) { + continue; + } + + std::vector scoringContextBatch; + 
scoringContextBatch.reserve(std::min(uniqueUncachedHistories.size(), maxBatchSize_)); + for (auto scoringContext : uniqueUncachedHistories) { + scoringContextBatch.push_back(scoringContext); + if (scoringContextBatch.size() == maxBatchSize_) { // Batch is full -> forward now + forwardBatch(scoringContextBatch); + scoringContextBatch.clear(); + } + } + + forwardBatch(scoringContextBatch); // Forward remaining histories + } + + /* + * Assign from cache map to result vector. NOTE(review): presumably no entry needed here has been evicted from the size-limited scoreCache_ by the forwardBatch calls above - confirm for requests with more unique contexts than the cache holds + */ + for (const auto& request : requests) { + StepOnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); + + verify(scoreCache_.contains(scoringContext)); + auto const& scores = scoreCache_.get(scoringContext)->get(); + + result.scores.push_back(scores.at(request.nextToken)); + } + + return result; +} + +std::optional StatefulTransducerOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { + auto result = computeScoresWithTimes({request}); + if (not result) { + return {}; + } + return ScoreWithTime{result->scores.front(), result->timeframes.front()}; +} + +size_t StatefulTransducerOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { + return 0u; // NOTE(review): activeContexts is ignored and index 0 is always reported - presumably no buffered input is ever released; confirm against BufferedLabelScorer +} + +OnnxHiddenStateRef StatefulTransducerOnnxLabelScorer::updatedHiddenState(OnnxHiddenStateRef const& hiddenState, LabelIndex nextToken) { + /* + * Create session inputs + */ + std::vector> sessionInputs; + sessionInputs.emplace_back(updaterTokenName_, Onnx::Value::create(std::vector{static_cast(nextToken)})); + + for (auto const& [inputName, stateName] : updaterInputToStateNameMap_) { + sessionInputs.emplace_back(inputName, hiddenState->stateValueMap.at(stateName)); + } + + /* + * Run session + */ + std::vector sessionOutputNames; + std::vector stateNames; + for (auto const& [outputName, stateName] : updaterOutputToStateNameMap_) { + sessionOutputNames.push_back(outputName); + stateNames.push_back(stateName); + } + + std::vector sessionOutputs; + 
stateUpdaterOnnxModel_.session.run(std::move(sessionInputs), sessionOutputNames, sessionOutputs); + + /* + * Return resulting hidden state + */ + auto newHiddenState = Core::ref(new OnnxHiddenState(std::move(stateNames), std::move(sessionOutputs))); + + return newHiddenState; +} + +void StatefulTransducerOnnxLabelScorer::finalizeScoringContext(StepOnnxHiddenStateScoringContextRef const& scoringContext) { + // If this scoring context does not need finalization, don't change it + if (not scoringContext->requiresFinalize) { + return; + } + + verify(not scoringContext->labelSeq.empty()); + + scoringContext->hiddenState = updatedHiddenState(scoringContext->hiddenState, scoringContext->labelSeq.back()); + scoringContext->requiresFinalize = false; +} + +void StatefulTransducerOnnxLabelScorer::forwardBatch(std::vector const& scoringContextBatch) { + if (scoringContextBatch.empty()) { + return; + } + + /* + * Create session inputs. Only the feature at the first context's step is fetched, so all contexts in the batch must share the same currentStep + */ + auto inputFeatureDataView = getInput(scoringContextBatch.front()->currentStep); + f32 const* inputFeatureData = inputFeatureDataView->data(); + std::vector inputFeatureShape = {1ul, static_cast(inputFeatureDataView->size())}; + + std::vector> sessionInputs; + sessionInputs.emplace_back(scorerInputFeatureName_, Onnx::Value::create(inputFeatureData, inputFeatureShape)); + + for (auto const& [inputName, stateName] : scorerInputToStateNameMap_) { + // Collect a vector of individual state values of shape [1, *] and afterwards concatenate + // them to a batched state tensor of shape [B, *] + std::vector stateValues; + stateValues.reserve(scoringContextBatch.size()); + + for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) { + auto scoringContext = scoringContextBatch[b]; + auto hiddenState = scoringContext->hiddenState; + stateValues.push_back(&hiddenState->stateValueMap.at(stateName)); + } + sessionInputs.emplace_back(inputName, Onnx::Value::concat(stateValues, 0)); + } + + /* + * Run session + */ + std::vector sessionOutputs; + 
scorerOnnxModel_.session.run(std::move(sessionInputs), {scorerScoresName_}, sessionOutputs); + + /* + * Put resulting scores into cache map + */ + for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) { + std::vector scoreVec; + sessionOutputs.front().get(b, scoreVec); + scoreCache_.put(scoringContextBatch[b], std::move(scoreVec)); + } +} + +} // namespace Nn diff --git a/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.hh new file mode 100644 index 000000000..c0121ac0c --- /dev/null +++ b/src/Nn/LabelScorer/StatefulTransducerOnnxLabelScorer.hh @@ -0,0 +1,130 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef STATEFUL_TRANSDUCER_ONNX_LABEL_SCORER_HH +#define STATEFUL_TRANSDUCER_ONNX_LABEL_SCORER_HH + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "BufferedLabelScorer.hh" +#include "ScoringContext.hh" + +namespace Nn { + +/* + * Label Scorer that performs scoring by forwarding hidden states through an ONNX model. 
+ * This Label Scorer requires three ONNX models: + * - A State Initializer which produces the hidden states for the first step + * - A State Updater which produces updated hidden states based on the previous hidden states and the next token + * - A Scorer which computes scores based on the current input feature and the hidden states + * + * The hidden states can be any number of ONNX tensors of any shape and type. + * Each ONNX model must have metadata that specifies the mapping of its input and output names to the corresponding state names. + * These state names need to be consistent over all three models. + * + * For example: + * - The State Initializer has an output called "lstm_c" and {"lstm_c": "LSTM_C"} in its metadata + * - The State Updater has input "lstm_c_in", output "lstm_c_out" and {"lstm_c_in": "LSTM_C", "lstm_c_out": "LSTM_C"} in its metadata + * - The Scorer has input "lstm_c" and {"lstm_c": "LSTM_C"} in its metadata + * Here, "LSTM_C" is the state name, which is the same across all three models, while the specific input/output names are arbitrary. + * + * The State Initializer must have all states as output. + * The State Updater must have the token and a subset of states as input and all states as output. + * The Scorer must have a subset of states and a feature as input. + * + * A common use case for this Label Scorer would be a Transducer model with unlimited context. + * + * Note: This LabelScorer is similar to the `StatefulOnnxLabelScorer`. The difference is that in this one the ScoringContext also + * contains the current step and the input feature at the current step is fed to the Scorer. Furthermore, the state initializer + * takes no inputs and the state updater only takes a token and hidden states, but no input features.
+ */ +class StatefulTransducerOnnxLabelScorer : public BufferedLabelScorer { + using Precursor = BufferedLabelScorer; + + static const Core::ParameterBool paramBlankUpdatesHistory; + static const Core::ParameterBool paramLoopUpdatesHistory; + static const Core::ParameterBool paramVerticalLabelTransition; + static const Core::ParameterInt paramMaxBatchSize; + static const Core::ParameterInt paramMaxCachedScores; + +public: + StatefulTransducerOnnxLabelScorer(const Core::Configuration& config); + virtual ~StatefulTransducerOnnxLabelScorer() = default; + + void reset() override; + + // Run the state initializer once (without inputs) and cache the resulting initial scoring context + Core::Ref getInitialScoringContext() override; + + // Advance the step and/or append the next token; the state-updater run is deferred until the context is finalized + Core::Ref extendedScoringContext(LabelScorer::Request const& request) override; + + std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + std::optional computeScoresWithTimes(std::vector const& requests) override; + +protected: + size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + +private: + // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache + void forwardBatch(std::vector const& historyBatch); + + // Computes new hidden state based on previous hidden state and next token through state-updater call + OnnxHiddenStateRef updatedHiddenState(OnnxHiddenStateRef const& hiddenState, LabelIndex nextToken); + + // Replace hidden-state in scoringContext with an updated version that includes the last label + void finalizeScoringContext(StepOnnxHiddenStateScoringContextRef const& scoringContext); + + void setupEncoderStatesValue(); // NOTE(review): declared but not defined in StatefulTransducerOnnxLabelScorer.cc - presumably a leftover, confirm and remove + void setupEncoderStatesSizeValue(); // NOTE(review): declared but not defined in StatefulTransducerOnnxLabelScorer.cc - presumably a leftover, confirm and remove + + bool blankUpdatesHistory_; + bool loopUpdatesHistory_; + bool verticalLabelTransition_; + size_t maxBatchSize_; + + Onnx::Model scorerOnnxModel_; + Onnx::Model stateInitializerOnnxModel_; + Onnx::Model 
stateUpdaterOnnxModel_; + + StepOnnxHiddenStateScoringContextRef initialScoringContext_; + + // Map input/output names of onnx models to hidden state names taken from state initializer model + std::unordered_map initializerOutputToStateNameMap_; + std::unordered_map updaterInputToStateNameMap_; + std::unordered_map updaterOutputToStateNameMap_; + std::unordered_map scorerInputToStateNameMap_; + + std::string scorerInputFeatureName_; + std::string scorerScoresName_; + + std::string updaterTokenName_; + + Core::FIFOCache, ScoringContextHash, ScoringContextEq> scoreCache_; +}; + +} // namespace Nn + +#endif // STATEFUL_TRANSDUCER_ONNX_LABEL_SCORER_HH diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index db673397f..2f4a1fc9f 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -24,6 +24,7 @@ #include "LabelScorer/NoContextOnnxLabelScorer.hh" #include "LabelScorer/NoOpLabelScorer.hh" #include "LabelScorer/StatefulOnnxLabelScorer.hh" +#include "LabelScorer/StatefulTransducerOnnxLabelScorer.hh" #include "Statistics.hh" #ifdef MODULE_NN @@ -128,6 +129,13 @@ Module_::Module_() [](Core::Configuration const& config) { return Core::ref(new StatefulOnnxLabelScorer(config)); }); + + // Compute scores based on input-feature and hidden-state where the hidden-state only depends on the token history + labelScorerFactory_.registerLabelScorer( + "stateful-transducer-onnx", + [](Core::Configuration const& config) { + return Core::ref(new StatefulTransducerOnnxLabelScorer(config)); + }); }; Module_::~Module_() {