Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7502b82
Add TransitionLabelScorer
SimBe195 Jul 24, 2025
7430001
Rewrite docstring
SimBe195 Jul 24, 2025
2a6272e
Clean up includes
SimBe195 Jul 24, 2025
7e325e1
Rewrite docstring again
SimBe195 Jul 24, 2025
a276136
Merge branch 'master' into tdp_label_scorer
SimBe195 Sep 24, 2025
d2d78fe
Refactor params to string list with compile time check
SimBe195 Sep 24, 2025
303fa46
Remove transitionTypeToIndex function and revert associated changes
SimBe195 Sep 24, 2025
ddd75c7
Revert unnecessary static_cast
SimBe195 Sep 24, 2025
b856c1e
Change std=c++17 to c++20
SimBe195 Sep 30, 2025
70699c0
Merge remote-tracking branch 'origin/version-bump' into tdp_label_scorer
SimBe195 Sep 30, 2025
5b89d0f
Move transition type string array to LabelScorer.hh
SimBe195 Sep 30, 2025
8b27b19
Add parameter for ignoring transition types in LabelScorer
SimBe195 Oct 1, 2025
1e877c7
Add missing parenthesis in description
SimBe195 Oct 1, 2025
3dcadee
Add some docstrings for the `Internal` functions
SimBe195 Oct 1, 2025
b9d919b
Move transitionTypeArray to protected space
SimBe195 Oct 1, 2025
23df463
Merge branch 'tdp_label_scorer' into disabled-transition-types
SimBe195 Oct 1, 2025
a437c04
Merge branch 'master' into disabled-transition-types
curufinwe Oct 8, 2025
8337fea
Fix order in .cc files
curufinwe Oct 9, 2025
0cfdf3d
Add `set` function to Core::CollapsedVector
SimBe195 Oct 10, 2025
85522b5
Apply suggestions from code review
SimBe195 Oct 10, 2025
4098ad4
Fix compilation
SimBe195 Oct 10, 2025
0d34085
Introduce configurable presets of enabled transition types + extras
SimBe195 Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/Nn/LabelScorer/CombineLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ protected:

// Compute weighted scores of requests with all sub-scorers
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns ALL: request filtering is expected to happen in the sub-scorers, so this scorer lets every
// transition type through.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::ALL;
}
Comment on lines +76 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see others

};

} // namespace Nn
Expand Down
5 changes: 5 additions & 0 deletions src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ protected:
// Run requests through decoder component
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns ALL: request filtering is expected to happen in the decoder, so this scorer lets every
// transition type through.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::ALL;
}
Comment on lines +71 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see others


private:
Core::Ref<Encoder> encoder_;
Core::Ref<LabelScorer> decoder_;
Expand Down
4 changes: 2 additions & 2 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() {
return hist;
}

void FixedContextOnnxLabelScorer::size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
auto minTimeIndex = Core::Type<Speech::TimeframeIndex>::max;
for (auto const& context : activeContexts.internalData()) {
SeqStepScoringContextRef stepHistory(dynamic_cast<const SeqStepScoringContext*>(context.get()));
Expand All @@ -107,7 +107,7 @@ void FixedContextOnnxLabelScorer::size_t FixedContextOnnxLabelScorer::getMinActi
return minTimeIndex;
}

cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
Precursor::cleanupCaches(activeContexts);

std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end());
Expand Down
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ protected:
// Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns TRANSDUCER, i.e. the label/blank transitions without LABEL_LOOP.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::TRANSDUCER;
}

private:
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache
// Assumes that all histories in the batch are based on the same timestep
Expand Down
120 changes: 93 additions & 27 deletions src/Nn/LabelScorer/LabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,29 @@ namespace Nn {
* =============================
*/

const Core::ParameterStringVector LabelScorer::paramIgnoredTransitionTypes(
"ignored-transition-types",
"Transition types that should be ignored by the label scorer (i.e. get assigned score 0 and do not affect the ScoringContext)",
// Maps the strings accepted by the `transition-preset` parameter to their TransitionPresetType values.
// "all" is listed so that the ALL preset, which enableTransitionTypes() already handles and which several
// scorers return from defaultPreset(), can also be requested explicitly from the configuration.
const Core::Choice LabelScorer::choiceTransitionPreset(
        "default", TransitionPresetType::DEFAULT,
        "none", TransitionPresetType::NONE,
        "all", TransitionPresetType::ALL,
        "ctc", TransitionPresetType::CTC,
        "transducer", TransitionPresetType::TRANSDUCER,
        "lm", TransitionPresetType::LM,
        Core::Choice::endMark());

// Configuration parameter selecting the preset of enabled transition types.
// Its default value is DEFAULT, which enableTransitionTypes() replaces with the result of defaultPreset().
const Core::ParameterChoice LabelScorer::paramTransitionPreset(
"transition-preset",
&LabelScorer::choiceTransitionPreset,
"Preset for which transition types should be enabled for the label scorer. Disabled transition types get assigned score 0 and do not affect the ScoringContext.",
TransitionPresetType::DEFAULT);

// Configuration parameter listing transition types to enable on top of the selected preset.
// enableTransitionTypes() looks each name up in transitionTypeArray_ and reports unknown names via error().
// NOTE(review): the "," argument is presumably the list delimiter - confirm against Core::ParameterStringVector.
const Core::ParameterStringVector LabelScorer::paramExtraTransitionTypes(
"extra-transition-types",
"Transition types that should be enabled in addition to the ones given by the preset.",
",");

LabelScorer::LabelScorer(const Core::Configuration& config)
: Core::Component(config),
ignoredTransitionTypes_() {
auto ignoredTransitionTypeStrings = paramIgnoredTransitionTypes(config);
for (auto const& transitionTypeString : ignoredTransitionTypeStrings) {
auto it = std::find_if(transitionTypeArray_.begin(),
transitionTypeArray_.end(),
[&](auto const& entry) { return entry.first == transitionTypeString; });
if (it != transitionTypeArray_.end()) {
ignoredTransitionTypes_.insert(it->second);
}
else {
error() << "Ignored transition type name '" << transitionTypeString << "' is not a valid identifier";
}
}
enabledTransitionTypes_() {
enableTransitionTypes(config);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I have to withdraw my approval. I just found a bug.
Here you are calling enableTransitionTypes(config), and enableTransitionTypes() in turn calls defaultPreset(). Because this happens in the constructor of LabelScorer, the LabelScorer implementation of defaultPreset() will always be called instead of the override in the derived class that is actually being used.
As a fix, I suggest calling enableTransitionTypes(config) in the constructor of every derived class and removing the call from the base class constructor here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eugen: yes please

}

void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
Expand All @@ -53,17 +56,17 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
}

ScoringContextRef LabelScorer::extendedScoringContext(Request const& request) {
if (ignoredTransitionTypes_.contains(request.transitionType)) {
return request.context;
if (enabledTransitionTypes_.contains(request.transitionType)) {
return extendedScoringContextInternal(request);
}
return extendedScoringContextInternal(request);
return request.context;
}

std::optional<LabelScorer::ScoreWithTime> LabelScorer::computeScoreWithTime(Request const& request) {
if (ignoredTransitionTypes_.contains(request.transitionType)) {
return ScoreWithTime{0.0, 0};
if (enabledTransitionTypes_.contains(request.transitionType)) {
return computeScoreWithTimeInternal(request);
}
return computeScoreWithTimeInternal(request);
return ScoreWithTime{0.0, 0};
}

std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
Expand All @@ -76,7 +79,7 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(

for (size_t requestIndex = 0ul; requestIndex < requests.size(); ++requestIndex) {
auto const& request = requests[requestIndex];
if (not ignoredTransitionTypes_.contains(request.transitionType)) {
if (enabledTransitionTypes_.contains(request.transitionType)) {
nonIgnoredRequests.push_back(request);
nonIgnoredRequestIndices.push_back(requestIndex);
}
Expand All @@ -89,12 +92,13 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(
}

// Interleave actual results with 0 scores for requests with ignored transition types
ScoresWithTimes result{{requests.size(), 0.0}, {requests.size(), 0}};
ScoresWithTimes result{
.scores = std::vector<Score>(requests.size(), 0.0),
.timeframes{requests.size(), 0}};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.timeframes{requests.size(), 0}};
.timeframes{requests.size(), 0}
};

for (size_t i = 0ul; i < nonIgnoredRequestIndices.size(); ++i) {
auto requestResult = nonIgnoredResults[i];
auto requestIndex = nonIgnoredRequestIndices[i];
result.scores[requestIndex] = requestResult.score;
result.timeframes.set(requestIndex, requestResult.timeframe);
result.scores[requestIndex] = nonIgnoredResults->scores[i];
result.timeframes.set(requestIndex, nonIgnoredResults->timeframes[i]);
}

return result;
Expand Down Expand Up @@ -122,4 +126,66 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimesI
return result;
}

void LabelScorer::enableTransitionTypes(Core::Configuration const& config) {
auto preset = paramTransitionPreset(config);
if (preset == TransitionPresetType::DEFAULT) {
preset = defaultPreset();
}
verify(preset != TransitionPresetType::DEFAULT);

switch (preset) {
case TransitionPresetType::NONE:
break;
case TransitionPresetType::ALL:
for (auto const& [_, transitionType] : transitionTypeArray_) {
enabledTransitionTypes_.insert(transitionType);
}
break;
case TransitionPresetType::CTC:
enabledTransitionTypes_ = {
LABEL_TO_LABEL,
LABEL_LOOP,
LABEL_TO_BLANK,
BLANK_TO_LABEL,
BLANK_LOOP,
INITIAL_LABEL,
INITIAL_BLANK,
};
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to PR #152, but TRANSDUCER and LM additionally need the SENTENCE_END transition types.

case TransitionPresetType::TRANSDUCER:
enabledTransitionTypes_ = {
LABEL_TO_LABEL,
LABEL_TO_BLANK,
BLANK_TO_LABEL,
BLANK_LOOP,
INITIAL_LABEL,
INITIAL_BLANK,
};
break;
case TransitionPresetType::LM:
enabledTransitionTypes_ = {
LABEL_TO_LABEL,
INITIAL_LABEL,
Comment on lines +165 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AedTreeBuilder and labelsync searches are still WIP and not merged yet, but at some point we will also need a preset for AED and I guess this will be the same as this one. So will we then just add a preset which is exactly the same, just with a different name? And if yes, should LM or AED be the default preset of the StatefulOnnxLabelScorer? I mean in the end it's the same, but it might become confusing because of the naming.

};
break;
}

auto extraTransitionTypeStrings = paramExtraTransitionTypes(config);
for (auto const& transitionTypeString : extraTransitionTypeStrings) {
auto it = std::find_if(transitionTypeArray_.begin(),
transitionTypeArray_.end(),
[&](auto const& entry) { return entry.first == transitionTypeString; });
if (it != transitionTypeArray_.end()) {
enabledTransitionTypes_.insert(it->second);
}
else {
error() << "Extra transition type name '" << transitionTypeString << "' is not a valid identifier";
}
}

if (enabledTransitionTypes_.empty()) {
error() << "Label scorer has no enabled transition types. Activate a preset and/or add extra transition types that should be considered.";
}
}

} // namespace Nn
24 changes: 21 additions & 3 deletions src/Nn/LabelScorer/LabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,14 @@ namespace Nn {
*/
class LabelScorer : public virtual Core::Component,
public Core::ReferenceCounted {
static const Core::ParameterStringVector paramIgnoredTransitionTypes;

public:
typedef Search::Score Score;

static const Core::Choice choiceTransitionPreset;
static const Core::ParameterChoice paramTransitionPreset;

static const Core::ParameterStringVector paramExtraTransitionTypes;

enum TransitionType {
LABEL_TO_LABEL,
LABEL_LOOP,
Expand All @@ -89,6 +92,15 @@ public:
numTypes, // must remain at the end
};

// Presets of enabled transition types; resolved into a concrete set by enableTransitionTypes().
enum TransitionPresetType {
DEFAULT,  // placeholder: replaced by the scorer's defaultPreset() before the preset is applied
NONE,  // the preset itself enables no transition types
ALL,  // enables every transition type listed in transitionTypeArray_
CTC,  // LABEL_TO_LABEL, LABEL_LOOP, LABEL_TO_BLANK, BLANK_TO_LABEL, BLANK_LOOP, INITIAL_LABEL, INITIAL_BLANK
TRANSDUCER,  // same set as CTC but without LABEL_LOOP
LM,  // LABEL_TO_LABEL and INITIAL_LABEL only
};

// Request for scoring or context extension
struct Request {
ScoringContextRef context;
Expand Down Expand Up @@ -166,8 +178,14 @@ protected:
// By default loops over the single-request version
virtual std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests);

// Preset that enableTransitionTypes() falls back to when `transition-preset` is configured as DEFAULT.
// The base implementation returns NONE, i.e. the preset on its own enables no transition types.
// NOTE(review): enableTransitionTypes() is invoked from the LabelScorer constructor; a virtual call made while
// the base class is being constructed resolves to this base version, so overrides in derived classes are not
// picked up on that path.
virtual TransitionPresetType defaultPreset() const {
return TransitionPresetType::NONE;
}
Comment on lines +181 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to move method definitions out of the class. Can either be inline in the header or just in the .cc file.


private:
std::unordered_set<TransitionType> ignoredTransitionTypes_;
std::unordered_set<TransitionType> enabledTransitionTypes_;

void enableTransitionTypes(Core::Configuration const& config);
};

} // namespace Nn
Expand Down
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ protected:
// Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns CTC, i.e. the label/blank transitions including LABEL_LOOP.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::CTC;
}
Comment on lines +64 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see others


private:
Onnx::Model onnxModel_;

Expand Down
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/NoOpLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ protected:

// Gets the buffered score for the requested token at the requested step
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns CTC, i.e. the label/blank transitions including LABEL_LOOP.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::CTC;
}
Comment on lines +48 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

method definition not in class declaration

};

} // namespace Nn
Expand Down
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ protected:
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns LM, which enables only LABEL_TO_LABEL and INITIAL_LABEL.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::LM;
}
Comment on lines +89 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move definition out of class declaration.


private:
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache
void forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& historyBatch);
Expand Down
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/TransitionLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ protected:
// Compute scores of base scorer and add transition scores based on transition types of the requests
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;

// Preset used by LabelScorer::enableTransitionTypes() when `transition-preset` is configured as `default`.
// Returns ALL, which enables every transition type listed in transitionTypeArray_.
TransitionPresetType defaultPreset() const override {
    return TransitionPresetType::ALL;
}

private:
std::unordered_map<TransitionType, Score> transitionScores_;

Expand Down