Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7502b82
Add TransitionLabelScorer
SimBe195 Jul 24, 2025
7430001
Rewrite docstring
SimBe195 Jul 24, 2025
2a6272e
Clean up includes
SimBe195 Jul 24, 2025
7e325e1
Rewrite docstring again
SimBe195 Jul 24, 2025
a276136
Merge branch 'master' into tdp_label_scorer
SimBe195 Sep 24, 2025
d2d78fe
Refactor params to string list with compile time check
SimBe195 Sep 24, 2025
303fa46
Remove transitionTypeToIndex function and revert associated changes
SimBe195 Sep 24, 2025
ddd75c7
Revert unnecessary static_cast
SimBe195 Sep 24, 2025
b856c1e
Change std=c++17 to c++20
SimBe195 Sep 30, 2025
70699c0
Merge remote-tracking branch 'origin/version-bump' into tdp_label_scorer
SimBe195 Sep 30, 2025
5b89d0f
Move transition type string array to LabelScorer.hh
SimBe195 Sep 30, 2025
8b27b19
Add parameter for ignoring transition types in LabelScorer
SimBe195 Oct 1, 2025
1e877c7
Add missing parenthesis in description
SimBe195 Oct 1, 2025
3dcadee
Add some docstrings for the `Internal` functions
SimBe195 Oct 1, 2025
b9d919b
Move transitionTypeArray to protected space
SimBe195 Oct 1, 2025
23df463
Merge branch 'tdp_label_scorer' into disabled-transition-types
SimBe195 Oct 1, 2025
a437c04
Merge branch 'master' into disabled-transition-types
curufinwe Oct 8, 2025
8337fea
Fix order in .cc files
curufinwe Oct 9, 2025
0cfdf3d
Add `set` function to Core::CollapsedVector
SimBe195 Oct 10, 2025
85522b5
Apply suggestions from code review
SimBe195 Oct 10, 2025
4098ad4
Fix compilation
SimBe195 Oct 10, 2025
0d34085
Introduce configurable presets of enabled transition types + extras
SimBe195 Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/Core/CollapsedVector.hh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public:
inline void push_back(const T& value);
inline const T& operator[](size_t idx) const;
inline const T& at(size_t idx) const;
inline void set(size_t idx, const T& value);
inline size_t size() const noexcept;
inline void clear() noexcept;
inline void reserve(size_t size);
Expand Down Expand Up @@ -105,6 +106,21 @@ inline const T& CollapsedVector<T>::at(size_t idx) const {
return (*this)[idx];
}

// Overwrite the element at logical position `idx` with `value`.
// Throws std::out_of_range if `idx` is not smaller than the logical size.
// If exactly one element is stored and `value` differs from it, the storage is first expanded
// to the logical size (filled with the stored element) so that only position `idx` changes.
// If `value` equals the single stored element, nothing needs to be done.
template<typename T>
inline void CollapsedVector<T>::set(size_t idx, const T& value) {
    if (idx >= logicalSize_) {
        throw std::out_of_range("Trying to access illegal index of CollapsedVector");
    }
    if (data_.size() != 1ul) {
        // Storage is not a single element (assumed to hold logicalSize_ entries here — TODO confirm):
        // overwrite in place. Nothing may be appended, otherwise data_ grows beyond logicalSize_.
        data_[idx] = value;
    }
    else if (value != data_.front()) {
        data_.resize(logicalSize_, data_.front());
        data_[idx] = value;
    }
}

template<typename T>
inline size_t CollapsedVector<T>::size() const noexcept {
return logicalSize_;
Expand Down
40 changes: 22 additions & 18 deletions src/Nn/LabelScorer/CombineLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,6 @@ ScoringContextRef CombineLabelScorer::getInitialScoringContext() {
return Core::ref(new CombineScoringContext(std::move(scoringContexts)));
}

// Build the extended combined context by letting every sub-scorer extend its own sub-context
// with the same next token and transition type.
ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& request) {
    auto const* combined = dynamic_cast<const CombineScoringContext*>(request.context.get());

    std::vector<ScoringContextRef> subContexts;
    subContexts.reserve(scaledScorers_.size());

    // Sub-contexts are walked in lockstep with the sub-scorers
    auto subContextIt = combined->scoringContexts.begin();
    for (auto& scaledScorer : scaledScorers_) {
        Request subRequest{*subContextIt, request.nextToken, request.transitionType};
        subContexts.push_back(scaledScorer.scorer->extendedScoringContext(subRequest));
        ++subContextIt;
    }

    return Core::ref(new CombineScoringContext(std::move(subContexts)));
}

void CombineLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
std::vector<const CombineScoringContext*> combineContexts;
combineContexts.reserve(activeContexts.internalSize());
Expand Down Expand Up @@ -101,7 +85,23 @@ void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
}
}

std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTime(Request const& request) {
// Build the extended combined context by letting every sub-scorer extend its own sub-context
// with the same next token and transition type. The sub-scorers are called through their
// public `extendedScoringContext`.
ScoringContextRef CombineLabelScorer::extendedScoringContextInternal(Request const& request) {
    auto const* combined = dynamic_cast<const CombineScoringContext*>(request.context.get());

    std::vector<ScoringContextRef> subContexts;
    subContexts.reserve(scaledScorers_.size());

    // Sub-contexts are walked in lockstep with the sub-scorers
    auto subContextIt = combined->scoringContexts.begin();
    for (auto& scaledScorer : scaledScorers_) {
        Request subRequest{*subContextIt, request.nextToken, request.transitionType};
        subContexts.push_back(scaledScorer.scorer->extendedScoringContext(subRequest));
        ++subContextIt;
    }

    return Core::ref(new CombineScoringContext(std::move(subContexts)));
}

std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTimeInternal(Request const& request) {
// Initialize accumulated result with zero-valued score and timestep
ScoreWithTime accumResult{0.0, 0};

Expand Down Expand Up @@ -130,7 +130,11 @@ std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTi
return accumResult;
}

std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimes(std::vector<Request> const& requests) {
std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimesInternal(std::vector<Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}

// Initialize accumulated results with zero-valued scores and timesteps
ScoresWithTimes accumResult{std::vector<Score>(requests.size(), 0.0), {requests.size(), 0}};

Expand Down
23 changes: 14 additions & 9 deletions src/Nn/LabelScorer/CombineLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ public:
// Combine initial ScoringContexts from all sub-scorers
ScoringContextRef getInitialScoringContext() override;

// Combine extended ScoringContexts from all sub-scorers
ScoringContextRef extendedScoringContext(Request const& request) override;

// Cleanup all sub-scorers
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

Expand All @@ -58,19 +55,27 @@ public:
// Add inputs to all sub-scorers
virtual void addInputs(DataView const& input, size_t nTimesteps) override;

// Compute weighted score of request with all sub-scorers
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;

// Compute weighted scores of requests with all sub-scorers
std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests) override;

protected:
// A sub-scorer together with its associated scale
struct ScaledLabelScorer {
// The wrapped sub-scorer
Core::Ref<LabelScorer> scorer;
// NOTE(review): presumably the weight applied to this sub-scorer's scores when combining — confirm in the .cc
Score scale;
};

std::vector<ScaledLabelScorer> scaledScorers_;

// Combine extended ScoringContexts from all sub-scorers
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

// Compute weighted score of request with all sub-scorers
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;

// Compute weighted scores of requests with all sub-scorers
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;

// Request filtering should happen in the sub-scorers, this one should let everything through.
// Hence the default preset of this scorer is `ALL`.
virtual TransitionPresetType defaultPreset() const override {
return TransitionPresetType::ALL;
}
};

} // namespace Nn
Expand Down
16 changes: 10 additions & 6 deletions src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ ScoringContextRef EncoderDecoderLabelScorer::getInitialScoringContext() {
return decoder_->getInitialScoringContext();
}

// Context extension is delegated unchanged to the decoder component
ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request const& request) {
return decoder_->extendedScoringContext(request);
}

// Cache cleanup is forwarded to the decoder component with the same set of active contexts
void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
decoder_->cleanupCaches(activeContexts);
}
Expand All @@ -59,11 +55,19 @@ void EncoderDecoderLabelScorer::signalNoMoreFeatures() {
decoder_->signalNoMoreFeatures();
}

std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
// Context extension is delegated unchanged to the decoder's public `extendedScoringContext`
ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContextInternal(Request const& request) {
return decoder_->extendedScoringContext(request);
}

// Scoring of a single request is delegated unchanged to the decoder's public `computeScoreWithTime`
std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
return decoder_->computeScoreWithTime(request);
}

std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
// Batched scoring is delegated to the decoder's public `computeScoresWithTimes`
std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
// An empty batch yields an empty result without calling the decoder
if (requests.empty()) {
return ScoresWithTimes{};
}

return decoder_->computeScoresWithTimes(requests);
}

Expand Down
16 changes: 11 additions & 5 deletions src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ public:
// Get start context from decoder component
ScoringContextRef getInitialScoringContext() override;

// Get extended context from decoder component
ScoringContextRef extendedScoringContext(Request const& request) override;

// Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are
// retrieved.
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
Expand All @@ -60,11 +57,20 @@ public:
// Same as `addInput` but adds features for multiple timesteps at once
void addInputs(DataView const& input, size_t nTimesteps) override;

protected:
// Get extended context from decoder component
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

// Run request through decoder component
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

// Run requests through decoder component
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Request filtering should happen in the decoder, this one should let everything through.
// Hence the default preset of this scorer is `ALL`.
virtual TransitionPresetType defaultPreset() const override {
return TransitionPresetType::ALL;
}

private:
Core::Ref<Encoder> encoder_;
Expand Down
56 changes: 30 additions & 26 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,32 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() {
return hist;
}

ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
// Return the smallest `currentStep` among the active scoring contexts.
// If there are no active contexts, the maximum TimeframeIndex value is returned.
size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
    auto smallestStep = Core::Type<Speech::TimeframeIndex>::max;

    for (auto const& activeContext : activeContexts.internalData()) {
        auto const* stepContext = dynamic_cast<const SeqStepScoringContext*>(activeContext.get());
        if (stepContext->currentStep < smallestStep) {
            smallestStep = stepContext->currentStep;
        }
    }

    return smallestStep;
}

// Let the precursor clean up first, then drop every cached score entry whose scoring context
// is not among the active contexts anymore.
void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
    Precursor::cleanupCaches(activeContexts);

    // Collect the active contexts in a hash set for fast membership checks
    auto const& contexts = activeContexts.internalData();
    std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> stillActive(contexts.begin(), contexts.end());

    auto cacheIt = scoreCache_.begin();
    while (cacheIt != scoreCache_.end()) {
        if (stillActive.count(cacheIt->first) != 0) {
            ++cacheIt;
            continue;
        }
        // `erase` returns the iterator following the removed entry
        cacheIt = scoreCache_.erase(cacheIt);
    }
}

ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
SeqStepScoringContextRef context(dynamic_cast<const SeqStepScoringContext*>(request.context.get()));

bool pushToken = false;
Expand Down Expand Up @@ -144,22 +169,11 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore
return Core::ref(new SeqStepScoringContext(std::move(newLabelSeq), context->currentStep + timeIncrement));
}

void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
Precursor::cleanupCaches(activeContexts);

std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end());

for (auto it = scoreCache_.begin(); it != scoreCache_.end();) {
if (activeContextSet.find(it->first) == activeContextSet.end()) {
it = scoreCache_.erase(it);
}
else {
++it;
}
std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}
}

std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
ScoresWithTimes result;
result.scores.reserve(requests.size());

Expand Down Expand Up @@ -232,24 +246,14 @@ std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::compute
return result;
}

std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
// Score a single request by wrapping it into a batch of size one, running it through
// `computeScoresWithTimes` and unpacking the first (only) score and timeframe.
// Returns an empty optional if the batched call produced no result.
std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
    auto batchResult = computeScoresWithTimes({request});
    if (batchResult.has_value()) {
        return ScoreWithTime{batchResult->scores.front(), batchResult->timeframes.front()};
    }
    return std::nullopt;
}

// Return the smallest `currentStep` among the active scoring contexts.
// If there are no active contexts, the maximum TimeframeIndex value is returned.
size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
    auto smallestStep = Core::Type<Speech::TimeframeIndex>::max;

    for (auto const& activeContext : activeContexts.internalData()) {
        auto const* stepContext = dynamic_cast<const SeqStepScoringContext*>(activeContext.get());
        if (stepContext->currentStep < smallestStep) {
            smallestStep = stepContext->currentStep;
        }
    }

    return smallestStep;
}

void FixedContextOnnxLabelScorer::forwardBatch(std::vector<SeqStepScoringContextRef> const& contextBatch) {
if (contextBatch.empty()) {
return;
Expand Down
20 changes: 12 additions & 8 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,28 @@ public:
// Initial scoring context contains step 0 and a history vector filled with the start label index
ScoringContextRef getInitialScoringContext() override;

// Clean up input buffer as well as cached score vectors that are no longer needed
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;

// May increment the step by 1 (except for vertical transitions) and may append the next token to the
// history label sequence depending on the transition type and whether loops/blanks update the history
// or not
ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override;

// Clean up input buffer as well as cached score vectors that are no longer needed
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override;

// If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to
// compute the scores and cache them
// Then, retrieve scores from cache
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Uses `computeScoresWithTimes` internally with some wrapping for vector packing/expansion
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;
// The default preset of this scorer is `TRANSDUCER`.
// NOTE(review): presumably this selects which transition types are enabled by default — confirm against TransitionPresetType.
virtual TransitionPresetType defaultPreset() const override {
return TransitionPresetType::TRANSDUCER;
}

private:
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache
Expand Down
Loading