4 changes: 3 additions & 1 deletion src/Search/LanguageModelLookahead.cc
@@ -649,6 +649,8 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const&

for (HMMStateNetwork::SuccessorIterator target = tree_.successors(node); target; ++target) {
if (not target.isLabel()) {
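// Skip self-loops: following a state's transition to itself would recurse forever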
if (*target == node)
continue;
build(*target, depth + 1);
successors.push_back(*target);
}
@@ -743,7 +745,7 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const&
collected[node] = -2;

for (HMMStateNetwork::SuccessorIterator edges = tree_.successors(node); edges; ++edges) {
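// The added *edges != node check mirrors the fix in build() above: self-loops
// must not be followed when computing the topological depth of a state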
if (not edges.isLabel()) {
if (not edges.isLabel() and *edges != node) {
int depth2 = collectTopologicalStates(*edges, depth + 1, topologicalStates, collected);
if (depth2 - 1 < depth) {
depth = depth2 - 1;
130 changes: 128 additions & 2 deletions src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc
@@ -21,6 +21,8 @@
#include <Core/CollapsedVector.hh>
#include <Core/XmlStream.hh>
#include <Lattice/LatticeAdaptor.hh>
#include <Lm/BackingOff.hh>
#include <Lm/Module.hh>
#include <Nn/LabelScorer/LabelScorer.hh>
#include <Nn/LabelScorer/ScoringContext.hh>
#include "Search/Module.hh"
@@ -38,8 +40,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis()
: scoringContext(),
currentToken(Nn::invalidLabelIndex),
currentState(invalidTreeNodeIndex),
lookahead(),
lmHistory(),
lookaheadHistory(),
fullLookaheadHistory(),
score(0.0),
lookaheadScore(0.0),
trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {}

TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
@@ -49,8 +55,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
: scoringContext(newScoringContext),
currentToken(extension.nextToken),
currentState(extension.state),
lookahead(extension.lookahead),
lmHistory(extension.lmHistory),
lookaheadHistory(extension.lookaheadHistory),
fullLookaheadHistory(extension.fullLookaheadHistory),
score(extension.score),
lookaheadScore(extension.lmScore),
trace(base.trace) {
if (extension.pron != nullptr) { // Word-end hypothesis -> update base trace and start a new trace for the next word
auto completedTrace = Core::ref(new LatticeTrace(*base.trace));
@@ -115,6 +125,21 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramCollapseRepeatedLabels(
"Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.",
false);

const Core::ParameterBool TreeTimesyncBeamSearch::paramLmLookahead(
"lm-lookahead",
"Enable language model lookahead.",
false);

const Core::ParameterBool TreeTimesyncBeamSearch::paramSeparateLookaheadLm(
"separate-lookahead-lm",
"Use a separate LM for lookahead.",
false);

const Core::ParameterBool TreeTimesyncBeamSearch::paramSparseLmLookAhead(
"sparse-lm-lookahead",
"Use sparse n-gram LM lookahead.",
true);

const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack(
"sentence-end-fall-back",
"Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.",
@@ -140,6 +165,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
cacheCleanupInterval_(paramCacheCleanupInterval(config)),
useBlank_(),
collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)),
enableLmLookahead_(paramLmLookahead(config)),
separateLookaheadLm_(paramSeparateLookaheadLm(config)),
sparseLmLookahead_(paramSparseLmLookAhead(config)),
sentenceEndFallback_(paramSentenceEndFallBack(config)),
logStepwiseStatistics_(paramLogStepwiseStatistics(config)),
labelScorer_(),
@@ -215,6 +243,34 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const&
// Create look-ups for state successors and exits of each state
createSuccessorLookups();

// Set lookahead LM
if (enableLmLookahead_) {
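// Precedence for the lookahead LM: an explicitly configured separate LM first,
// then a lookahead LM provided by the recognition LM itself, and otherwise the
// recognition LM is reused as-is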
if (separateLookaheadLm_) {
log() << "Use separate lookahead LM";
lookaheadLm_ = Lm::Module::instance().createScaledLanguageModel(select("lm-lookahead"), lexicon_);
}
else if (languageModel_->lookaheadLanguageModel().get() != nullptr) {
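// The LM-provided lookahead model is exposed const, so it is re-wrapped in its own
// scaling layer here; the const_cast should be safe since scaling does not modify the LM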
lookaheadLm_ = Core::Ref<Lm::ScaledLanguageModel>(new Lm::LanguageModelScaling(select("lookahead-lm"),
Core::Ref<Lm::LanguageModel>(const_cast<Lm::LanguageModel*>(languageModel_->lookaheadLanguageModel().get()))));
}
else {
lookaheadLm_ = languageModel_;
}

if (sparseLmLookahead_ && !dynamic_cast<const Lm::BackingOffLm*>(lookaheadLm_->unscaled().get())) {
warning() << "Not using sparse LM lookahead, because the LM is not a backing-off LM.";
sparseLmLookahead_ = false;
}

lmLookahead_ = new LanguageModelLookahead(Core::Configuration(config, "lm-lookahead"),
modelCombination.pronunciationScale(),
lookaheadLm_,
network_->structure,
network_->rootState,
network_->exits,
acousticModel_);
}

reset();

// Create global cache
@@ -240,6 +296,11 @@ void TreeTimesyncBeamSearch::reset() {
beam_.front().currentState = network_->rootState;
beam_.front().lmHistory = languageModel_->startHistory();

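// Both lookahead histories start out identical; lookaheadHistory may later be
// shortened by the back-off handling, while fullLookaheadHistory always keeps
// the unreduced context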
if (enableLmLookahead_) {
beam_.front().lookaheadHistory = lookaheadLm_->startHistory();
beam_.front().fullLookaheadHistory = lookaheadLm_->startHistory();
}

currentSearchStep_ = 0ul;
finishedSegment_ = false;

@@ -330,7 +391,10 @@ bool TreeTimesyncBeamSearch::decodeStep() {
{tokenIdx,
nullptr,
successorState,
hyp.lookahead,
hyp.lmHistory,
hyp.lookaheadHistory,
hyp.fullLookaheadHistory,
hyp.score,
0.0,
0,
@@ -355,6 +419,14 @@ bool TreeTimesyncBeamSearch::decodeStep() {
for (size_t requestIdx = 0ul; requestIdx < extensions_.size(); ++requestIdx) {
extensions_[requestIdx].score += result->scores[requestIdx];
extensions_[requestIdx].timeframe = result->timeframes[requestIdx];

// Add the LM lookahead score to the extensions' scores for pruning
// Make sure not to calculate the lookahead score for the blank lemma, which is reachable from the root
if (enableLmLookahead_ and not(beam_[extensions_[requestIdx].baseHypIndex].currentState == network_->rootState and extensions_[requestIdx].nextToken == blankLabelIndex_)) {
Score lookaheadScore = getLmLookaheadScore(extensions_[requestIdx]);
extensions_[requestIdx].lmScore = lookaheadScore;
extensions_[requestIdx].score += lookaheadScore;
}
}

if (logStepwiseStatistics_) {
@@ -404,6 +476,12 @@ bool TreeTimesyncBeamSearch::decodeStep() {
for (size_t hypIndex = 0ul; hypIndex < newBeam_.size(); ++hypIndex) {
auto& hyp = newBeam_[hypIndex];

if (enableLmLookahead_) {
// Subtract the LM lookahead score again
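// (it was only added so that pruning could take an LM estimate into account;
// the actual LM score enters the hypothesis score at word ends)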
hyp.score -= hyp.lookaheadScore;
hyp.lookaheadScore = 0;
}

std::vector<PersistentStateTree::Exit> exitList = exitLookup_[hyp.currentState];
if (not exitList.empty()) {
// Create one word-end hypothesis for each exit
@@ -414,7 +492,10 @@
ExtensionCandidate wordEndExtension{hyp.currentToken,
lemmaPron,
exit.transitState, // Start from the root node (the exit's transit state) in the next step
hyp.lookahead,
hyp.lmHistory,
hyp.lookaheadHistory,
hyp.fullLookaheadHistory,
hyp.score,
0.0,
static_cast<TimeframeIndex>(currentSearchStep_),
@@ -444,14 +525,24 @@ bool TreeTimesyncBeamSearch::decodeStep() {
clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", extensions_.size());
}

// Create new word-end label hypotheses from word-end extension candidates and update the LM history
// Create new word-end label hypotheses from word-end extension candidates, update the LM history, and prepare a new lookahead table if the lookahead history has changed
wordEndHypotheses_.clear();
for (auto& extension : extensions_) {
const Bliss::Lemma* lemma = extension.pron->lemma();
if (lemma != lexicon_->specialLemma("blank") and lemma != lexicon_->specialLemma("silence")) {
const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence();
const Bliss::SyntacticToken* st = sts.front();
extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st);

if (enableLmLookahead_) {
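// Extend from fullLookaheadHistory: lookaheadHistory may have been reduced for
// back-off in getLmLookaheadScore() and would otherwise lose context here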
Lm::History newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st);

if (!(newLookaheadHistory == extension.lookaheadHistory)) {
getLmLookahead(extension.lookahead, newLookaheadHistory);
extension.lookaheadHistory = newLookaheadHistory;
extension.fullLookaheadHistory = newLookaheadHistory;
}
}
}

auto const& baseHyp = newBeam_[extension.baseHypIndex];
@@ -689,6 +780,41 @@ void TreeTimesyncBeamSearch::recombination(std::vector<TreeTimesyncBeamSearch::L
hypotheses.swap(recombinedHypotheses_);
}

void TreeTimesyncBeamSearch::getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history) {
lookahead = lmLookahead_->getLookahead(history);
lmLookahead_->fill(lookahead, sparseLmLookahead_);
}

Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension) {
if (!extension.lookahead) {
getLmLookahead(extension.lookahead, extension.lookaheadHistory);
}

Score lookaheadScore = 0;
bool scoreFound = false;
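// If the (sparse) lookahead table has no entry for this node, walk down the
// back-off chain: add the LM back-off penalty, shorten the history by one token,
// and retry with the lookahead table of the reduced history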
do {
if (extension.lookahead->isSparse()) { // Sparse lookahead
auto lookaheadHash = lmLookahead_->lookaheadHash(extension.state);
scoreFound = extension.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore);
}
else { // Non-sparse lookahead
auto lookaheadId = lmLookahead_->lookaheadId(extension.state);
lookaheadScore = extension.lookahead->scoreForLookAheadIdNormal(lookaheadId);
scoreFound = true;
}

if (!scoreFound) { // No lookahead table entry, use back-off
const Lm::BackingOffLm* lm = dynamic_cast<const Lm::BackingOffLm*>(lookaheadLm_->unscaled().get());
lookaheadScore += lm->getBackOffScore(extension.lookaheadHistory);
// Reduce the history and retrieve the corresponding lookahead table
extension.lookaheadHistory = lm->reducedHistory(extension.lookaheadHistory, lm->historyLength(extension.lookaheadHistory) - 1);
getLmLookahead(extension.lookahead, extension.lookaheadHistory);
}
} while (!scoreFound);

return lookaheadScore;
}

void TreeTimesyncBeamSearch::createSuccessorLookups() {
stateSuccessorLookup_.resize(network_->structure.stateCount());
exitLookup_.resize(network_->structure.stateCount());
@@ -746,4 +872,4 @@ void TreeTimesyncBeamSearch::finalizeLmScoring() {
beam_.swap(newBeam_);
}

} // namespace Search
} // namespace Search