-
Notifications
You must be signed in to change notification settings - Fork 16
Partial enabling of transition types for different label scorers #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
7502b82
7430001
2a6272e
7e325e1
a276136
d2d78fe
303fa46
ddd75c7
b856c1e
70699c0
5b89d0f
8b27b19
1e877c7
3dcadee
b9d919b
23df463
a437c04
8337fea
0cfdf3d
85522b5
4098ad4
0d34085
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,6 +67,11 @@ protected: | |
| // Run requests through decoder component | ||
| std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override; | ||
|
|
||
| // Request filtering should happen in the decoder, this one should let everything through | ||
| virtual TransitionPresetType defaultPreset() const override { | ||
| return TransitionPresetType::ALL; | ||
| } | ||
|
Comment on lines
+71
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see others |
||
|
|
||
| private: | ||
| Core::Ref<Encoder> encoder_; | ||
| Core::Ref<LabelScorer> decoder_; | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,26 +23,29 @@ namespace Nn { | |||||||
| * ============================= | ||||||||
| */ | ||||||||
|
|
||||||||
| const Core::ParameterStringVector LabelScorer::paramIgnoredTransitionTypes( | ||||||||
| "ignored-transition-types", | ||||||||
| "Transition types that should be ignored by the label scorer (i.e. get assigned score 0 and do not affect the ScoringContext)", | ||||||||
| const Core::Choice LabelScorer::choiceTransitionPreset( | ||||||||
| "default", TransitionPresetType::DEFAULT, | ||||||||
| "none", TransitionPresetType::NONE, | ||||||||
| "ctc", TransitionPresetType::CTC, | ||||||||
| "transducer", TransitionPresetType::TRANSDUCER, | ||||||||
| "lm", TransitionPresetType::LM, | ||||||||
| Core::Choice::endMark()); | ||||||||
|
|
||||||||
| const Core::ParameterChoice LabelScorer::paramTransitionPreset( | ||||||||
| "transition-preset", | ||||||||
| &LabelScorer::choiceTransitionPreset, | ||||||||
| "Preset for which transition types should be enabled for the label scorer. Disabled transition types get assigned score 0 and do not affect the ScoringContext.", | ||||||||
| TransitionPresetType::DEFAULT); | ||||||||
|
|
||||||||
| const Core::ParameterStringVector LabelScorer::paramExtraTransitionTypes( | ||||||||
| "extra-transition-types", | ||||||||
| "Transition types that should be enabled in addition to the ones given by the preset.", | ||||||||
| ","); | ||||||||
|
|
||||||||
| LabelScorer::LabelScorer(const Core::Configuration& config) | ||||||||
| : Core::Component(config), | ||||||||
| ignoredTransitionTypes_() { | ||||||||
| auto ignoredTransitionTypeStrings = paramIgnoredTransitionTypes(config); | ||||||||
| for (auto const& transitionTypeString : ignoredTransitionTypeStrings) { | ||||||||
| auto it = std::find_if(transitionTypeArray_.begin(), | ||||||||
| transitionTypeArray_.end(), | ||||||||
| [&](auto const& entry) { return entry.first == transitionTypeString; }); | ||||||||
| if (it != transitionTypeArray_.end()) { | ||||||||
| ignoredTransitionTypes_.insert(it->second); | ||||||||
| } | ||||||||
| else { | ||||||||
| error() << "Ignored transition type name '" << transitionTypeString << "' is not a valid identifier"; | ||||||||
| } | ||||||||
| } | ||||||||
| enabledTransitionTypes_() { | ||||||||
| enableTransitionTypes(config); | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I have to withdraw my approval. I just found a bug. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eugen: yes please |
||||||||
| } | ||||||||
|
|
||||||||
| void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { | ||||||||
|
|
@@ -53,17 +56,17 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { | |||||||
| } | ||||||||
|
|
||||||||
| ScoringContextRef LabelScorer::extendedScoringContext(Request const& request) { | ||||||||
| if (ignoredTransitionTypes_.contains(request.transitionType)) { | ||||||||
| return request.context; | ||||||||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||||||||
| return extendedScoringContextInternal(request); | ||||||||
| } | ||||||||
| return extendedScoringContextInternal(request); | ||||||||
| return request.context; | ||||||||
| } | ||||||||
|
|
||||||||
| std::optional<LabelScorer::ScoreWithTime> LabelScorer::computeScoreWithTime(Request const& request) { | ||||||||
| if (ignoredTransitionTypes_.contains(request.transitionType)) { | ||||||||
| return ScoreWithTime{0.0, 0}; | ||||||||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||||||||
| return computeScoreWithTimeInternal(request); | ||||||||
| } | ||||||||
| return computeScoreWithTimeInternal(request); | ||||||||
| return ScoreWithTime{0.0, 0}; | ||||||||
| } | ||||||||
|
|
||||||||
| std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) { | ||||||||
|
|
@@ -76,7 +79,7 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes( | |||||||
|
|
||||||||
| for (size_t requestIndex = 0ul; requestIndex < requests.size(); ++requestIndex) { | ||||||||
| auto const& request = requests[requestIndex]; | ||||||||
| if (not ignoredTransitionTypes_.contains(request.transitionType)) { | ||||||||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||||||||
| nonIgnoredRequests.push_back(request); | ||||||||
| nonIgnoredRequestIndices.push_back(requestIndex); | ||||||||
| } | ||||||||
|
|
@@ -89,12 +92,13 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes( | |||||||
| } | ||||||||
|
|
||||||||
| // Interleave actual results with 0 scores for requests with ignored transition types | ||||||||
| ScoresWithTimes result{{requests.size(), 0.0}, {requests.size(), 0}}; | ||||||||
| ScoresWithTimes result{ | ||||||||
| .scores = std::vector<Score>(requests.size(), 0.0), | ||||||||
| .timeframes{requests.size(), 0}}; | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| for (size_t i = 0ul; i < nonIgnoredRequestIndices.size(); ++i) { | ||||||||
| auto requestResult = nonIgnoredResults[i]; | ||||||||
| auto requestIndex = nonIgnoredRequestIndices[i]; | ||||||||
| result.scores[requestIndex] = requestResult.score; | ||||||||
| result.timeframes.set(requestIndex, requestResult.timeframe); | ||||||||
| result.scores[requestIndex] = nonIgnoredResults->scores[i]; | ||||||||
| result.timeframes.set(requestIndex, nonIgnoredResults->timeframes[i]); | ||||||||
| } | ||||||||
|
|
||||||||
| return result; | ||||||||
|
|
@@ -122,4 +126,66 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimesI | |||||||
| return result; | ||||||||
| } | ||||||||
|
|
||||||||
| void LabelScorer::enableTransitionTypes(Core::Configuration const& config) { | ||||||||
| auto preset = paramTransitionPreset(config); | ||||||||
| if (preset == TransitionPresetType::DEFAULT) { | ||||||||
| preset = defaultPreset(); | ||||||||
| } | ||||||||
| verify(preset != TransitionPresetType::DEFAULT); | ||||||||
|
|
||||||||
| switch (preset) { | ||||||||
| case TransitionPresetType::NONE: | ||||||||
| break; | ||||||||
| case TransitionPresetType::ALL: | ||||||||
| for (auto const& [_, transitionType] : transitionTypeArray_) { | ||||||||
| enabledTransitionTypes_.insert(transitionType); | ||||||||
| } | ||||||||
| break; | ||||||||
| case TransitionPresetType::CTC: | ||||||||
| enabledTransitionTypes_ = { | ||||||||
| LABEL_TO_LABEL, | ||||||||
| LABEL_LOOP, | ||||||||
| LABEL_TO_BLANK, | ||||||||
| BLANK_TO_LABEL, | ||||||||
| BLANK_LOOP, | ||||||||
| INITIAL_LABEL, | ||||||||
| INITIAL_BLANK, | ||||||||
| }; | ||||||||
| break; | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's related to this PR #152, but TRANSDUCER and LM additionally need SENTENCE_END transition types. |
||||||||
| case TransitionPresetType::TRANSDUCER: | ||||||||
| enabledTransitionTypes_ = { | ||||||||
| LABEL_TO_LABEL, | ||||||||
| LABEL_TO_BLANK, | ||||||||
| BLANK_TO_LABEL, | ||||||||
| BLANK_LOOP, | ||||||||
| INITIAL_LABEL, | ||||||||
| INITIAL_BLANK, | ||||||||
| }; | ||||||||
| break; | ||||||||
| case TransitionPresetType::LM: | ||||||||
| enabledTransitionTypes_ = { | ||||||||
| LABEL_TO_LABEL, | ||||||||
| INITIAL_LABEL, | ||||||||
|
Comment on lines
+165
to
+168
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The AedTreeBuilder and labelsync searches are still WIP and not merged yet, but at some point we will also need a preset for AED and I guess this will be the same as this one. So will we then just add a preset which is exactly the same, just with a different name? And if yes, should LM or AED be the default preset of the StatefulOnnxLabelScorer? I mean in the end it's the same, but it might become confusing because of the naming. |
||||||||
| }; | ||||||||
| break; | ||||||||
| } | ||||||||
|
|
||||||||
| auto extraTransitionTypeStrings = paramExtraTransitionTypes(config); | ||||||||
| for (auto const& transitionTypeString : extraTransitionTypeStrings) { | ||||||||
| auto it = std::find_if(transitionTypeArray_.begin(), | ||||||||
| transitionTypeArray_.end(), | ||||||||
| [&](auto const& entry) { return entry.first == transitionTypeString; }); | ||||||||
| if (it != transitionTypeArray_.end()) { | ||||||||
| enabledTransitionTypes_.insert(it->second); | ||||||||
| } | ||||||||
| else { | ||||||||
| error() << "Extra transition type name '" << transitionTypeString << "' is not a valid identifier"; | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| if (enabledTransitionTypes_.empty()) { | ||||||||
| error() << "Label scorer has no enabled transition types. Activate a preset and/or add extra transition types that should be considered."; | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| } // namespace Nn | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,11 +73,14 @@ namespace Nn { | |
| */ | ||
| class LabelScorer : public virtual Core::Component, | ||
| public Core::ReferenceCounted { | ||
| static const Core::ParameterStringVector paramIgnoredTransitionTypes; | ||
|
|
||
| public: | ||
| typedef Search::Score Score; | ||
|
|
||
| static const Core::Choice choiceTransitionPreset; | ||
| static const Core::ParameterChoice paramTransitionPreset; | ||
|
|
||
| static const Core::ParameterStringVector paramExtraTransitionTypes; | ||
|
|
||
| enum TransitionType { | ||
| LABEL_TO_LABEL, | ||
| LABEL_LOOP, | ||
|
|
@@ -89,6 +92,15 @@ public: | |
| numTypes, // must remain at the end | ||
| }; | ||
|
|
||
| enum TransitionPresetType { | ||
| DEFAULT, | ||
| NONE, | ||
| ALL, | ||
| CTC, | ||
| TRANSDUCER, | ||
| LM, | ||
| }; | ||
|
|
||
| // Request for scoring or context extension | ||
| struct Request { | ||
| ScoringContextRef context; | ||
|
|
@@ -166,8 +178,14 @@ protected: | |
| // By default loops over the single-request version | ||
| virtual std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests); | ||
|
|
||
| virtual TransitionPresetType defaultPreset() const { | ||
| return TransitionPresetType::NONE; | ||
| } | ||
|
Comment on lines
+181
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to move method definitions out of the class. Can either be inline in the header or just in the .cc file. |
||
|
|
||
| private: | ||
| std::unordered_set<TransitionType> ignoredTransitionTypes_; | ||
| std::unordered_set<TransitionType> enabledTransitionTypes_; | ||
|
|
||
| void enableTransitionTypes(Core::Configuration const& config); | ||
| }; | ||
|
|
||
| } // namespace Nn | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,10 @@ protected: | |
| // Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion | ||
| std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override; | ||
|
|
||
| virtual TransitionPresetType defaultPreset() const override { | ||
| return TransitionPresetType::CTC; | ||
| } | ||
|
Comment on lines
+64
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see others |
||
|
|
||
| private: | ||
| Onnx::Model onnxModel_; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,10 @@ protected: | |
|
|
||
| // Gets the buffered score for the requested token at the requested step | ||
| std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override; | ||
|
|
||
| virtual TransitionPresetType defaultPreset() const override { | ||
| return TransitionPresetType::CTC; | ||
| } | ||
|
Comment on lines
+48
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. method definition not in class declaration |
||
| }; | ||
|
|
||
| } // namespace Nn | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,6 +86,10 @@ protected: | |
| std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override; | ||
| std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override; | ||
|
|
||
| virtual TransitionPresetType defaultPreset() const override { | ||
| return TransitionPresetType::LM; | ||
| } | ||
|
Comment on lines
+89
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move definition out of class declaration. |
||
|
|
||
| private: | ||
| // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache | ||
| void forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& historyBatch); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see others