Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Nn/LabelScorer/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ LIBSPRINTLABELSCORER_O = \
$(OBJDIR)/NoContextOnnxLabelScorer.o \
$(OBJDIR)/NoOpLabelScorer.o \
$(OBJDIR)/ScoringContext.o \
$(OBJDIR)/StatefulOnnxLabelScorer.o
$(OBJDIR)/StatefulOnnxLabelScorer.o \
$(OBJDIR)/StatefulTransducerOnnxLabelScorer.o

# -----------------------------------------------------------------------------

Expand Down
32 changes: 32 additions & 0 deletions src/Nn/LabelScorer/ScoringContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,36 @@ bool OnnxHiddenStateScoringContext::isEqual(ScoringContextRef const& other) cons
return true;
}

/*
* =====================================
* = StepOnnxHiddenStateScoringContext =
* =====================================
*/
size_t StepOnnxHiddenStateScoringContext::hash() const {
return Core::combineHashes(currentStep, Core::MurmurHash3_x64_64(reinterpret_cast<void const*>(labelSeq.data()), labelSeq.size() * sizeof(LabelIndex), 0x78b174eb));
}

bool StepOnnxHiddenStateScoringContext::isEqual(ScoringContextRef const& other) const {
    // Only another StepOnnxHiddenStateScoringContext can compare equal; a
    // failed downcast means `other` has a different dynamic type.
    auto const* rhs = dynamic_cast<StepOnnxHiddenStateScoringContext const*>(other.get());
    if (rhs == nullptr) {
        return false;
    }

    // Equality is decided by the step and the label sequence (same length and
    // same elements in the same order). The hidden state is not compared.
    return currentStep == rhs->currentStep && labelSeq == rhs->labelSeq;
}

} // namespace Nn
23 changes: 23 additions & 0 deletions src/Nn/LabelScorer/ScoringContext.hh
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,29 @@ struct OnnxHiddenStateScoringContext : public ScoringContext {

typedef Core::Ref<const OnnxHiddenStateScoringContext> OnnxHiddenStateScoringContextRef;

/*
 * Scoring context made up of a step index, a label history and a hidden state.
 *
 * The hidden state is assumed to be equal for two contexts if and only if
 * both were created from the same label history, which is why the label
 * sequence is stored alongside it.
 */
struct StepOnnxHiddenStateScoringContext : public ScoringContext {
    Speech::TimeframeIndex     currentStep = 0u;
    std::vector<LabelIndex>    labelSeq;  // Used for hashing
    mutable OnnxHiddenStateRef hiddenState;
    mutable bool               requiresFinalize = false;

    // Creates a context at step 0 with an empty label sequence and a
    // default-constructed hidden state reference.
    StepOnnxHiddenStateScoringContext() {}

    // Creates a context for the given step, label sequence (copied) and
    // hidden state; `requiresFinalize` starts out as false.
    StepOnnxHiddenStateScoringContext(Speech::TimeframeIndex step, std::vector<LabelIndex> const& labelSeq, OnnxHiddenStateRef state)
            : currentStep(step), labelSeq(labelSeq), hiddenState(state) {}

    bool   isEqual(ScoringContextRef const& other) const;
    size_t hash() const;
};

// Core::Ref handle to a const StepOnnxHiddenStateScoringContext.
typedef Core::Ref<const StepOnnxHiddenStateScoringContext> StepOnnxHiddenStateScoringContextRef;

} // namespace Nn

#endif // SCORING_CONTEXT_HH
4 changes: 4 additions & 0 deletions src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ namespace Nn {
*
* A common use case for this Label Scorer would be an AED model with cross-attention over the encoder output.
* Since the encoder state inputs are optional, it can also be used for stateful language models without acoustic input.
*
* Note: This LabelScorer is similar to the `StatefulTransducerOnnxLabelScorer`. The difference is that this one assumes the
* input features are processed into the hidden states rather than being fed directly into the scorer. For this reason, the state
* initializer and updater here also take input features in addition to tokens.
*/
class StatefulOnnxLabelScorer : public BufferedLabelScorer {
using Precursor = BufferedLabelScorer;
Expand Down
Loading