diff --git a/src/Core/CollapsedVector.hh b/src/Core/CollapsedVector.hh index e6b4492c..96113328 100644 --- a/src/Core/CollapsedVector.hh +++ b/src/Core/CollapsedVector.hh @@ -45,6 +45,7 @@ public: inline void push_back(const T& value); inline const T& operator[](size_t idx) const; inline const T& at(size_t idx) const; + inline void set(size_t idx, const T& value); inline size_t size() const noexcept; inline void clear() noexcept; inline void reserve(size_t size); @@ -105,6 +106,21 @@ inline const T& CollapsedVector::at(size_t idx) const { return (*this)[idx]; } +template +inline void CollapsedVector::set(size_t idx, const T& value) { + if (idx >= logicalSize_) { + throw std::out_of_range("Trying to access illegal index of CollapsedVector"); + } + if (data_.size() != 1ul) { + // Storage is not collapsed to a single element: overwrite in place and never grow data_ here, since logicalSize_ is unchanged by set() + data_[idx] = value; + } + else if (value != data_.front()) { + data_.resize(logicalSize_, data_.front()); + data_[idx] = value; + } +} + template inline size_t CollapsedVector::size() const noexcept { return logicalSize_; diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index 706c6c5c..242e8045 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -55,22 +55,6 @@ ScoringContextRef CombineLabelScorer::getInitialScoringContext() { return Core::ref(new CombineScoringContext(std::move(scoringContexts))); } -ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& request) { - auto combineContext = dynamic_cast(request.context.get()); - - std::vector extScoringContexts; - extScoringContexts.reserve(scaledScorers_.size()); - - auto scorerIt = scaledScorers_.begin(); - auto contextIt = combineContext->scoringContexts.begin(); - - for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) { - Request subRequest{*contextIt, request.nextToken, request.transitionType}; - extScoringContexts.push_back(scorerIt->scorer->extendedScoringContext(subRequest)); - } 
- return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); -} - void CombineLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { std::vector combineContexts; combineContexts.reserve(activeContexts.internalSize()); @@ -101,7 +85,23 @@ void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { } } -std::optional CombineLabelScorer::computeScoreWithTime(Request const& request) { +ScoringContextRef CombineLabelScorer::extendedScoringContextInternal(Request const& request) { + auto combineContext = dynamic_cast(request.context.get()); + + std::vector extScoringContexts; + extScoringContexts.reserve(scaledScorers_.size()); + + auto scorerIt = scaledScorers_.begin(); + auto contextIt = combineContext->scoringContexts.begin(); + + for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) { + Request subRequest{*contextIt, request.nextToken, request.transitionType}; + extScoringContexts.push_back(scorerIt->scorer->extendedScoringContext(subRequest)); + } + return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); +} + +std::optional CombineLabelScorer::computeScoreWithTimeInternal(Request const& request) { // Initialize accumulated result with zero-valued score and timestep ScoreWithTime accumResult{0.0, 0}; @@ -130,7 +130,11 @@ std::optional CombineLabelScorer::computeScoreWithTi return accumResult; } -std::optional CombineLabelScorer::computeScoresWithTimes(std::vector const& requests) { +std::optional CombineLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; + } + // Initialize accumulated results with zero-valued scores and timesteps ScoresWithTimes accumResult{std::vector(requests.size(), 0.0), {requests.size(), 0}}; diff --git a/src/Nn/LabelScorer/CombineLabelScorer.hh b/src/Nn/LabelScorer/CombineLabelScorer.hh index a6baf688..2cb727ad 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.hh +++ 
b/src/Nn/LabelScorer/CombineLabelScorer.hh @@ -46,9 +46,6 @@ public: // Combine initial ScoringContexts from all sub-scorers ScoringContextRef getInitialScoringContext() override; - // Combine extended ScoringContexts from all sub-scorers - ScoringContextRef extendedScoringContext(Request const& request) override; - // Cleanup all sub-scorers void cleanupCaches(Core::CollapsedVector const& activeContexts) override; @@ -58,12 +55,6 @@ public: // Add inputs to all sub-scorers virtual void addInputs(DataView const& input, size_t nTimesteps) override; - // Compute weighted score of request with all sub-scorers - std::optional computeScoreWithTime(Request const& request) override; - - // Compute weighted scores of requests with all sub-scorers - std::optional computeScoresWithTimes(std::vector const& requests) override; - protected: struct ScaledLabelScorer { Core::Ref scorer; @@ -71,6 +62,20 @@ protected: }; std::vector scaledScorers_; + + // Combine extended ScoringContexts from all sub-scorers + ScoringContextRef extendedScoringContextInternal(Request const& request) override; + + // Compute weighted score of request with all sub-scorers + std::optional computeScoreWithTimeInternal(Request const& request) override; + + // Compute weighted scores of requests with all sub-scorers + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; + + // Request filtering should happen in the sub-scorers, this one should let everything through + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::ALL; + } }; } // namespace Nn diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc index 257c08e1..66ef464d 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc @@ -33,10 +33,6 @@ ScoringContextRef EncoderDecoderLabelScorer::getInitialScoringContext() { return decoder_->getInitialScoringContext(); } 
-ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request const& request) { - return decoder_->extendedScoringContext(request); -} - void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { decoder_->cleanupCaches(activeContexts); } @@ -59,11 +55,19 @@ void EncoderDecoderLabelScorer::signalNoMoreFeatures() { decoder_->signalNoMoreFeatures(); } -std::optional EncoderDecoderLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContextInternal(Request const& request) { + return decoder_->extendedScoringContext(request); +} + +std::optional EncoderDecoderLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { return decoder_->computeScoreWithTime(request); } -std::optional EncoderDecoderLabelScorer::computeScoresWithTimes(std::vector const& requests) { +std::optional EncoderDecoderLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; + } + return decoder_->computeScoresWithTimes(requests); } diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh index 204204aa..df6ff5da 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh @@ -46,9 +46,6 @@ public: // Get start context from decoder component ScoringContextRef getInitialScoringContext() override; - // Get extended context from decoder component - ScoringContextRef extendedScoringContext(Request const& request) override; - // Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are // retrieved. 
void cleanupCaches(Core::CollapsedVector const& activeContexts) override; @@ -60,11 +57,20 @@ public: // Same as `addInput` but adds features for multiple timesteps at once void addInputs(DataView const& input, size_t nTimesteps) override; +protected: + // Get extended context from decoder component + ScoringContextRef extendedScoringContextInternal(Request const& request) override; + // Run request through decoder component - std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + std::optional computeScoreWithTimeInternal(LabelScorer::Request const& request) override; // Run requests through decoder component - std::optional computeScoresWithTimes(std::vector const& requests) override; + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; + + // Request filtering should happen in the decoder, this one should let everything through + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::ALL; + } private: Core::Ref encoder_; diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc index 33d178f2..f52bd056 100644 --- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc @@ -97,7 +97,32 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() { return hist; } -ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { +size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { + auto minTimeIndex = Core::Type::max; + for (auto const& context : activeContexts.internalData()) { + SeqStepScoringContextRef stepHistory(dynamic_cast(context.get())); + minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep); + } + + return minTimeIndex; +} + +void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + 
Precursor::cleanupCaches(activeContexts); + + std::unordered_set activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end()); + + for (auto it = scoreCache_.begin(); it != scoreCache_.end();) { + if (activeContextSet.find(it->first) == activeContextSet.end()) { + it = scoreCache_.erase(it); + } + else { + ++it; + } + } +} + +ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { SeqStepScoringContextRef context(dynamic_cast(request.context.get())); bool pushToken = false; @@ -144,22 +169,11 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore return Core::ref(new SeqStepScoringContext(std::move(newLabelSeq), context->currentStep + timeIncrement)); } -void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { - Precursor::cleanupCaches(activeContexts); - - std::unordered_set activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end()); - - for (auto it = scoreCache_.begin(); it != scoreCache_.end();) { - if (activeContextSet.find(it->first) == activeContextSet.end()) { - it = scoreCache_.erase(it); - } - else { - ++it; - } +std::optional FixedContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; } -} -std::optional FixedContextOnnxLabelScorer::computeScoresWithTimes(std::vector const& requests) { ScoresWithTimes result; result.scores.reserve(requests.size()); @@ -232,7 +246,7 @@ std::optional FixedContextOnnxLabelScorer::compute return result; } -std::optional FixedContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +std::optional FixedContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { auto result = computeScoresWithTimes({request}); if (not result.has_value()) { return {}; @@ -240,16 +254,6 @@ std::optional 
FixedContextOnnxLabelScorer::computeSc return ScoreWithTime{result->scores.front(), result->timeframes.front()}; } -size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { - auto minTimeIndex = Core::Type::max; - for (auto const& context : activeContexts.internalData()) { - SeqStepScoringContextRef stepHistory(dynamic_cast(context.get())); - minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep); - } - - return minTimeIndex; -} - void FixedContextOnnxLabelScorer::forwardBatch(std::vector const& contextBatch) { if (contextBatch.empty()) { return; diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh index a5aa9162..6484ea57 100644 --- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh @@ -47,24 +47,28 @@ public: // Initial scoring context contains step 0 and a history vector filled with the start label index ScoringContextRef getInitialScoringContext() override; + // Clean up input buffer as well as cached score vectors that are no longer needed + void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + +protected: + size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + // May increment the step by 1 (except for vertical transitions) and may append the next token to the // history label sequence depending on the transition type and whether loops/blanks update the history // or not - ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override; - - // Clean up input buffer as well as cached score vectors that are no longer needed - void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override; // If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to // 
compute the scores and cache them // Then, retreive scores from cache - std::optional computeScoresWithTimes(std::vector const& requests) override; + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; // Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion - std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + std::optional computeScoreWithTimeInternal(LabelScorer::Request const& request) override; -protected: - size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::TRANSDUCER; + } private: // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache diff --git a/src/Nn/LabelScorer/LabelScorer.cc b/src/Nn/LabelScorer/LabelScorer.cc index 010aa56a..4e05969f 100644 --- a/src/Nn/LabelScorer/LabelScorer.cc +++ b/src/Nn/LabelScorer/LabelScorer.cc @@ -23,8 +23,30 @@ namespace Nn { * ============================= */ +const Core::Choice LabelScorer::choiceTransitionPreset( + "default", TransitionPresetType::DEFAULT, + "none", TransitionPresetType::NONE, + "ctc", TransitionPresetType::CTC, + "transducer", TransitionPresetType::TRANSDUCER, + "lm", TransitionPresetType::LM, + Core::Choice::endMark()); + +const Core::ParameterChoice LabelScorer::paramTransitionPreset( + "transition-preset", + &LabelScorer::choiceTransitionPreset, + "Preset for which transition types should be enabled for the label scorer. 
Disabled transition types get assigned score 0 and do not affect the ScoringContext.", + TransitionPresetType::DEFAULT); + +const Core::ParameterStringVector LabelScorer::paramExtraTransitionTypes( + "extra-transition-types", + "Transition types that should be enabled in addition to the ones given by the preset.", + ","); + LabelScorer::LabelScorer(const Core::Configuration& config) - : Core::Component(config) {} + : Core::Component(config), + enabledTransitionTypes_() { + enableTransitionTypes(config); +} void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { auto featureSize = input.size() / nTimesteps; @@ -33,7 +55,60 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { } } +ScoringContextRef LabelScorer::extendedScoringContext(Request const& request) { + if (enabledTransitionTypes_.contains(request.transitionType)) { + return extendedScoringContextInternal(request); + } + return request.context; +} + +std::optional LabelScorer::computeScoreWithTime(Request const& request) { + if (enabledTransitionTypes_.contains(request.transitionType)) { + return computeScoreWithTimeInternal(request); + } + return ScoreWithTime{0.0, 0}; +} + std::optional LabelScorer::computeScoresWithTimes(std::vector const& requests) { + // First, collect all requests for which the transition type is not ignored + std::vector nonIgnoredRequests; + nonIgnoredRequests.reserve(requests.size()); + + std::vector nonIgnoredRequestIndices; + nonIgnoredRequestIndices.reserve(requests.size()); + + for (size_t requestIndex = 0ul; requestIndex < requests.size(); ++requestIndex) { + auto const& request = requests[requestIndex]; + if (enabledTransitionTypes_.contains(request.transitionType)) { + nonIgnoredRequests.push_back(request); + nonIgnoredRequestIndices.push_back(requestIndex); + } + } + + // Compute scores for non-ignored requests + auto nonIgnoredResults = computeScoresWithTimesInternal(nonIgnoredRequests); + if (not nonIgnoredResults) { + return {}; + } + 
+ // Interleave actual results with 0 scores for requests with ignored transition types + ScoresWithTimes result{ + .scores = std::vector(requests.size(), 0.0), + .timeframes{requests.size(), 0}}; + for (size_t i = 0ul; i < nonIgnoredRequestIndices.size(); ++i) { + auto requestIndex = nonIgnoredRequestIndices[i]; + result.scores[requestIndex] = nonIgnoredResults->scores[i]; + result.timeframes.set(requestIndex, nonIgnoredResults->timeframes[i]); + } + + return result; +} + +std::optional LabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; + } + // By default, just loop over the non-batched `computeScoreWithTime` and collect the results ScoresWithTimes result; @@ -51,4 +126,66 @@ std::optional LabelScorer::computeScoresWithTimes( return result; } +void LabelScorer::enableTransitionTypes(Core::Configuration const& config) { + auto preset = paramTransitionPreset(config); + if (preset == TransitionPresetType::DEFAULT) { + preset = defaultPreset(); + } + verify(preset != TransitionPresetType::DEFAULT); + + switch (preset) { + case TransitionPresetType::NONE: + break; + case TransitionPresetType::ALL: + for (auto const& [_, transitionType] : transitionTypeArray_) { + enabledTransitionTypes_.insert(transitionType); + } + break; + case TransitionPresetType::CTC: + enabledTransitionTypes_ = { + LABEL_TO_LABEL, + LABEL_LOOP, + LABEL_TO_BLANK, + BLANK_TO_LABEL, + BLANK_LOOP, + INITIAL_LABEL, + INITIAL_BLANK, + }; + break; + case TransitionPresetType::TRANSDUCER: + enabledTransitionTypes_ = { + LABEL_TO_LABEL, + LABEL_TO_BLANK, + BLANK_TO_LABEL, + BLANK_LOOP, + INITIAL_LABEL, + INITIAL_BLANK, + }; + break; + case TransitionPresetType::LM: + enabledTransitionTypes_ = { + LABEL_TO_LABEL, + INITIAL_LABEL, + }; + break; + } + + auto extraTransitionTypeStrings = paramExtraTransitionTypes(config); + for (auto const& transitionTypeString : extraTransitionTypeStrings) { + auto it = 
std::find_if(transitionTypeArray_.begin(), + transitionTypeArray_.end(), + [&](auto const& entry) { return entry.first == transitionTypeString; }); + if (it != transitionTypeArray_.end()) { + enabledTransitionTypes_.insert(it->second); + } + else { + error() << "Extra transition type name '" << transitionTypeString << "' is not a valid identifier"; + } + } + + if (enabledTransitionTypes_.empty()) { + error() << "Label scorer has no enabled transition types. Activate a preset and/or add extra transition types that should be considered."; + } +} + } // namespace Nn diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index ed6072c0..aee6f046 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -76,6 +76,11 @@ class LabelScorer : public virtual Core::Component, public: typedef Search::Score Score; + static const Core::Choice choiceTransitionPreset; + static const Core::ParameterChoice paramTransitionPreset; + + static const Core::ParameterStringVector paramExtraTransitionTypes; + enum TransitionType { LABEL_TO_LABEL, LABEL_LOOP, @@ -87,6 +92,15 @@ public: numTypes, // must remain at the end }; + enum TransitionPresetType { + DEFAULT, + NONE, + ALL, + CTC, + TRANSDUCER, + LM, + }; + // Request for scoring or context extension struct Request { ScoringContextRef context; @@ -120,7 +134,7 @@ public: virtual ScoringContextRef getInitialScoringContext() = 0; // Creates a copy of the context in the request that is extended using the given token and transition type - virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; + ScoringContextRef extendedScoringContext(Request const& request); // Given a collection of currently active contexts, this function can clean up values in any internal caches // or buffers that are saved for scoring contexts which no longer are active. 
@@ -136,13 +150,12 @@ public: // Return score and timeframe index of the corresponding output // May not return a value if the LabelScorer is not ready to score the request yet // (e.g. not enough features received) - virtual std::optional computeScoreWithTime(Request const& request) = 0; + std::optional computeScoreWithTime(Request const& request); // Perform scoring computation for a batch of requests // May be implemented more efficiently than iterated calls of `getScoreWithTime` // Return two vectors: one vector with scores and one vector with times - // By default loops over the single-request version - virtual std::optional computeScoresWithTimes(std::vector const& requests); + std::optional computeScoresWithTimes(std::vector const& requests); protected: inline static constexpr auto transitionTypeArray_ = std::to_array>({ @@ -155,6 +168,24 @@ protected: {"initial-blank", INITIAL_BLANK}, }); static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); + + // The public versions of these functions are implemented in this base class and handle the ignoring of transition types. + // These `Internal` versions contain the actual logic and should be overridden in child classes. 
+ + virtual ScoringContextRef extendedScoringContextInternal(Request const& request) = 0; + virtual std::optional computeScoreWithTimeInternal(Request const& request) = 0; + + // By default loops over the single-request version + virtual std::optional computeScoresWithTimesInternal(std::vector const& requests); + + virtual TransitionPresetType defaultPreset() const { + return TransitionPresetType::NONE; + } + +private: + std::unordered_set enabledTransitionTypes_; + + void enableTransitionTypes(Core::Configuration const& config); }; } // namespace Nn diff --git a/src/Nn/LabelScorer/NoContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/NoContextOnnxLabelScorer.cc index 7a41ebf3..959343c9 100644 --- a/src/Nn/LabelScorer/NoContextOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/NoContextOnnxLabelScorer.cc @@ -53,11 +53,6 @@ ScoringContextRef NoContextOnnxLabelScorer::getInitialScoringContext() { return Core::ref(new StepScoringContext()); } -ScoringContextRef NoContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { - StepScoringContextRef context(dynamic_cast(request.context.get())); - return Core::ref(new StepScoringContext(context->currentStep + 1)); -} - void NoContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { Precursor::cleanupCaches(activeContexts); @@ -73,7 +68,26 @@ void NoContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector NoContextOnnxLabelScorer::computeScoresWithTimes(std::vector const& requests) { +size_t NoContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { + auto minTimeIndex = Core::Type::max; + for (auto const& context : activeContexts.internalData()) { + StepScoringContextRef stepHistory(dynamic_cast(context.get())); + minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep); + } + + return minTimeIndex; +} + +ScoringContextRef NoContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { + 
StepScoringContextRef context(dynamic_cast(request.context.get())); + return Core::ref(new StepScoringContext(context->currentStep + 1)); +} + +std::optional NoContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; + } + ScoresWithTimes result; result.scores.reserve(requests.size()); @@ -115,7 +129,7 @@ std::optional NoContextOnnxLabelScorer::computeSco return result; } -std::optional NoContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +std::optional NoContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { auto result = computeScoresWithTimes({request}); if (not result.has_value()) { return {}; @@ -123,16 +137,6 @@ std::optional NoContextOnnxLabelScorer::computeScore return ScoreWithTime{result->scores.front(), result->timeframes.front()}; } -size_t NoContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { - auto minTimeIndex = Core::Type::max; - for (auto const& context : activeContexts.internalData()) { - StepScoringContextRef stepHistory(dynamic_cast(context.get())); - minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep); - } - - return minTimeIndex; -} - void NoContextOnnxLabelScorer::forwardContext(StepScoringContextRef const& context) { /* * Create session inputs diff --git a/src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh b/src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh index 1eecf824..283c313e 100644 --- a/src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh @@ -44,32 +44,36 @@ public: // Initial scoring context contains step 0 ScoringContextRef getInitialScoringContext() override; - // Increment the step by 1 - ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override; - // Clean up input buffer as well as cached score vectors that are no longer needed void 
cleanupCaches(Core::CollapsedVector const& activeContexts) override; +protected: + size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + + // Increment the step by 1 + ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override; + // If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to // compute the scores and cache them // Then, retreive scores from cache - std::optional computeScoresWithTimes(std::vector const& requests) override; + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; // Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion - std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + std::optional computeScoreWithTimeInternal(LabelScorer::Request const& request) override; -protected: - size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::CTC; + } private: - void forwardContext(StepScoringContextRef const& context); - Onnx::Model onnxModel_; std::string inputFeatureName_; std::string scoresName_; std::unordered_map, ScoringContextHash, ScoringContextEq> scoreCache_; + + void forwardContext(StepScoringContextRef const& context); }; } // namespace Nn diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.cc b/src/Nn/LabelScorer/NoOpLabelScorer.cc index 8a14300f..3a151f69 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.cc +++ b/src/Nn/LabelScorer/NoOpLabelScorer.cc @@ -25,12 +25,22 @@ ScoringContextRef StepwiseNoOpLabelScorer::getInitialScoringContext() { return Core::ref(new StepScoringContext()); } -ScoringContextRef StepwiseNoOpLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { +size_t StepwiseNoOpLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { + auto 
minInputIndex = Core::Type::max; + for (auto const& context : activeContexts.internalData()) { + StepScoringContextRef stepHistory(dynamic_cast(context.get())); + minInputIndex = std::min(minInputIndex, static_cast(stepHistory->currentStep)); + } + + return minInputIndex; +} + +ScoringContextRef StepwiseNoOpLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { StepScoringContextRef stepHistory(dynamic_cast(request.context.get())); return Core::ref(new StepScoringContext(stepHistory->currentStep + 1)); } -std::optional StepwiseNoOpLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +std::optional StepwiseNoOpLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { StepScoringContextRef stepHistory(dynamic_cast(request.context.get())); auto input = getInput(stepHistory->currentStep); if (not input) { @@ -40,14 +50,4 @@ std::optional StepwiseNoOpLabelScorer::computeScoreW return ScoreWithTime{(*input)[request.nextToken], stepHistory->currentStep}; } -size_t StepwiseNoOpLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { - auto minInputIndex = Core::Type::max; - for (auto const& context : activeContexts.internalData()) { - StepScoringContextRef stepHistory(dynamic_cast(context.get())); - minInputIndex = std::min(minInputIndex, static_cast(stepHistory->currentStep)); - } - - return minInputIndex; -} - } // namespace Nn diff --git a/src/Nn/LabelScorer/NoOpLabelScorer.hh b/src/Nn/LabelScorer/NoOpLabelScorer.hh index ff9ddd2d..2ff863b6 100644 --- a/src/Nn/LabelScorer/NoOpLabelScorer.hh +++ b/src/Nn/LabelScorer/NoOpLabelScorer.hh @@ -36,14 +36,18 @@ public: // Initial scoring context just contains step 0. ScoringContextRef getInitialScoringContext() override; +protected: + size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + // Scoring context with step incremented by 1. 
- ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override; + ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override; // Gets the buffered score for the requested token at the requested step - std::optional computeScoreWithTime(LabelScorer::Request const& request) override; + std::optional computeScoreWithTimeInternal(LabelScorer::Request const& request) override; -protected: - size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::CTC; + } }; } // namespace Nn diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index be5bc817..8f4a3a1f 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -200,7 +200,22 @@ Core::Ref StatefulOnnxLabelScorer::getInitialScoringContex return Core::ref(new OnnxHiddenStateScoringContext()); // Sentinel empty Ref as initial hidden state } -Core::Ref StatefulOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { +void StatefulOnnxLabelScorer::addInput(DataView const& input) { + Precursor::addInput(input); + + initialHiddenState_ = OnnxHiddenStateRef(); + + if (not encoderStatesValue_.empty()) { // Any previously computed hidden state values are outdated now so reset them + encoderStatesValue_ = Onnx::Value(); + encoderStatesSizeValue_ = Onnx::Value(); + } +} + +size_t StatefulOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { + return 0u; +} + +Core::Ref StatefulOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { OnnxHiddenStateScoringContextRef history(dynamic_cast(request.context.get())); bool updateState = false; @@ -243,18 +258,11 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( return Core::ref(new 
OnnxHiddenStateScoringContext(std::move(newLabelSeq), newHiddenState)); } -void StatefulOnnxLabelScorer::addInput(DataView const& input) { - Precursor::addInput(input); - - initialHiddenState_ = OnnxHiddenStateRef(); - - if (not encoderStatesValue_.empty()) { // Any previously computed hidden state values are outdated now so reset them - encoderStatesValue_ = Onnx::Value(); - encoderStatesSizeValue_ = Onnx::Value(); +std::optional StatefulOnnxLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; } -} -std::optional StatefulOnnxLabelScorer::computeScoresWithTimes(std::vector const& requests) { if ((initializerEncoderStatesName_ != "" or initializerEncoderStatesSizeName_ != "" or updaterEncoderStatesName_ != "" or updaterEncoderStatesSizeName_ != "") and (expectMoreFeatures_ or bufferSize() == 0)) { // Only allow scoring once all encoder states have been passed return {}; @@ -304,7 +312,7 @@ std::optional StatefulOnnxLabelScorer::computeScor return result; } -std::optional StatefulOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +std::optional StatefulOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { auto result = computeScoresWithTimes({request}); if (not result) { return {}; @@ -312,10 +320,6 @@ std::optional StatefulOnnxLabelScorer::computeScoreW return ScoreWithTime{result->scores.front(), result->timeframes.front()}; } -size_t StatefulOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const { - return 0u; -} - void StatefulOnnxLabelScorer::setupEncoderStatesValue() { if (not encoderStatesValue_.empty()) { return; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh index 29f1777c..9ea4557a 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh @@ -74,18 +74,22 @@ public: // If 
startLabelIndex is set, forward that through the state updater to obtain the start history Core::Ref getInitialScoringContext() override; - // Forward hidden-state through state-updater ONNX model - Core::Ref extendedScoringContext(LabelScorer::Request const& request) override; - // Add a single encoder outputs to buffer void addInput(DataView const& input) override; - std::optional computeScoreWithTime(LabelScorer::Request const& request) override; - std::optional computeScoresWithTimes(std::vector const& requests) override; - protected: size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; + // Forward hidden-state through state-updater ONNX model + Core::Ref extendedScoringContextInternal(LabelScorer::Request const& request) override; + + std::optional computeScoreWithTimeInternal(LabelScorer::Request const& request) override; + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; + + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::LM; + } + private: // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache void forwardBatch(std::vector const& historyBatch); diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index f7afac18..3c4f40bd 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -42,10 +42,6 @@ ScoringContextRef TransitionLabelScorer::getInitialScoringContext() { return baseLabelScorer_->getInitialScoringContext(); } -ScoringContextRef TransitionLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { - return baseLabelScorer_->extendedScoringContext(request); -} - void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { baseLabelScorer_->cleanupCaches(activeContexts); } @@ -58,7 +54,11 @@ void TransitionLabelScorer::addInputs(DataView const& 
input, size_t nTimesteps) baseLabelScorer_->addInputs(input, nTimesteps); } -std::optional TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { +ScoringContextRef TransitionLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { + return baseLabelScorer_->extendedScoringContext(request); +} + +std::optional TransitionLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { auto result = baseLabelScorer_->computeScoreWithTime(request); if (result) { result->score += transitionScores_[request.transitionType]; @@ -66,7 +66,11 @@ std::optional TransitionLabelScorer::computeScoreWit return result; } -std::optional TransitionLabelScorer::computeScoresWithTimes(std::vector const& requests) { +std::optional TransitionLabelScorer::computeScoresWithTimesInternal(std::vector const& requests) { + if (requests.empty()) { + return ScoresWithTimes{}; + } + auto results = baseLabelScorer_->computeScoresWithTimes(requests); if (results) { for (size_t i = 0ul; i < requests.size(); ++i) { diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index a8bc4af0..ea5df749 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -41,9 +41,6 @@ public: // Initial context of base scorer ScoringContextRef getInitialScoringContext() override; - // Extend context via base scorer - ScoringContextRef extendedScoringContext(Request const& request) override; - // Clean up base scorer void cleanupCaches(Core::CollapsedVector const& activeContexts) override; @@ -53,11 +50,19 @@ public: // Add inputs to sub-scorer void addInputs(DataView const& input, size_t nTimesteps) override; +protected: + // Extend context via base scorer + ScoringContextRef extendedScoringContextInternal(Request const& request) override; + // Compute score of base scorer and add transition score based on transition type of the request - std::optional 
computeScoreWithTime(Request const& request) override; + std::optional computeScoreWithTimeInternal(Request const& request) override; // Compute scores of base scorer and add transition scores based on transition types of the requests - std::optional computeScoresWithTimes(std::vector const& requests) override; + std::optional computeScoresWithTimesInternal(std::vector const& requests) override; + + virtual TransitionPresetType defaultPreset() const override { + return TransitionPresetType::ALL; + } private: std::unordered_map transitionScores_;