
Diverse beam search (CPU) #55

Draft · wants to merge 1 commit into main
164 changes: 92 additions & 72 deletions flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp
@@ -32,6 +32,9 @@ void LexiconFreeSeq2SeqDecoder::decodeStep(
hyp_[0].clear();
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr);

// Size of each group
int grpSize = opt_.beamSize / opt_.numBeamGroups;

// Decode frame by frame
int t = 0;
for (; t < maxOutputLength_; t++) {
@@ -61,87 +64,104 @@

std::vector<size_t> idx(emittingModelScores.back().size());

// Generate new hypothesis
for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) {
const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo];
// Change nothing for completed hypothesis
if (prevHyp.token == eos_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score,
prevHyp.lmState,
&prevHyp,
eos_,
nullptr,
prevHyp.emittingModelScore,
prevHyp.lmScore,
hypo);
continue;
}

const EmittingModelStatePtr& outState = outStates[validHypo];
if (!outState) {
validHypo++;
continue;
}

std::iota(idx.begin(), idx.end(), 0);
if (emittingModelScores[validHypo].size() > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&emittingModelScores, &validHypo](
const size_t& l, const size_t& r) {
return emittingModelScores[validHypo][l] >
emittingModelScores[validHypo][r];
});
}

for (int r = 0; r <
std::min(emittingModelScores[validHypo].size(),
(size_t)opt_.beamSizeToken);
r++) {
int n = idx[r];
double emittingModelScore = emittingModelScores[validHypo][n];

if (n == eos_) { /* (1) Try eos */
auto lmStateScorePair = lm_->finish(prevHyp.lmState);
auto lmScore = lmStateScorePair.second;

// Iterate through beam groups; with a single group this reduces to vanilla beam search
int hypo = 0;
int validHypo = 0;

uniqueCandidateTokens_.clear();

for (int grp = 0; grp < opt_.numBeamGroups; grp++) {
// Generate new hypothesis
for (; hypo < std::min(hyp_[t].size(), (size_t)((grp + 1) * grpSize)); hypo++) {
const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo];
// Change nothing for completed hypothesis
if (prevHyp.token == eos_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score,
prevHyp.lmState,
&prevHyp,
eos_,
nullptr,
prevHyp.emittingModelScore,
prevHyp.lmScore,
hypo);
continue;
}

const EmittingModelStatePtr& outState = outStates[validHypo];
if (!outState) {
validHypo++;
continue;
}

std::iota(idx.begin(), idx.end(), 0);
if (emittingModelScores[validHypo].size() > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&emittingModelScores, &validHypo](
const size_t& l, const size_t& r) {
return emittingModelScores[validHypo][l] >
emittingModelScores[validHypo][r];
});
}

for (int r = 0; r <
std::min(emittingModelScores[validHypo].size(),
(size_t)opt_.beamSizeToken);
r++) {
int n = idx[r];

double diversityFactor = 0.0;
if (grp > 0) {
// Penalize tokens already chosen by earlier groups at this time step;
// the diversity factor only applies from the second group onward
diversityFactor = diversityFunction_(uniqueCandidateTokens_, n);
}
// Augment log probabilities with the diversity penalty
double emittingModelScore = emittingModelScores[validHypo][n] + diversityFactor;

if (n == eos_) { /* (1) Try eos */
auto lmStateScorePair = lm_->finish(prevHyp.lmState);
auto lmScore = lmStateScorePair.second;

candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + emittingModelScore + opt_.eosScore +
opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
nullptr,
prevHyp.emittingModelScore + emittingModelScore,
prevHyp.lmScore + lmScore,
hypo);
} else { /* (2) Try normal token */
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + emittingModelScore + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
outState,
prevHyp.emittingModelScore + emittingModelScore,
prevHyp.lmScore + lmScore,
hypo);
uniqueCandidateTokens_.insert(n);
}
}
validHypo++;
}
}
candidatesStore(
candidates_,
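
For reference, the diff above only wires a diversity callback into the decoder; the concrete penalty is left to the caller. A minimal sketch of a Hamming-style diversity term matching the call site diversityFunction_(uniqueCandidateTokens_, n) could look like the following — the function name and the penalty weight are illustrative assumptions, not part of this PR:

#include <unordered_set>

// Hamming-style diversity: return a negative penalty when this token has
// already been selected by an earlier beam group at the current time step,
// otherwise zero.
double hammingDiversity(const std::unordered_set<int>& chosenTokens, const int token) {
  const double diversityWeight = 0.5; // illustrative strength, not fixed by this PR
  return chosenTokens.count(token) > 0 ? -diversityWeight : 0.0;
}

Returning a non-positive value matches the decoder code above, which adds diversityFactor directly onto the emitting-model score.
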
13 changes: 10 additions & 3 deletions flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h
@@ -26,6 +26,7 @@ struct LexiconFreeSeq2SeqDecoderOptions {
double lmWeight; // Weight of lm
double eosScore; // Score for inserting an EOS
bool logAdd; // If or not use logadd when merging hypothesis
int numBeamGroups = 1; // For diverse beam search, number of beam groups to utilize. Defaults to 1 (non-diverse beam search).
};

/**
Expand Down Expand Up @@ -78,7 +79,7 @@ struct LexiconFreeSeq2SeqDecoderState {
int getWord() const {
return -1;
}
};
};

/**
* Decoder implements a beam seach decoder that finds the token transcription
@@ -100,12 +101,14 @@ class LexiconFreeSeq2SeqDecoder : public Decoder {
const LMPtr& lm,
const int eos,
EmittingModelUpdateFunc emittingModelUpdateFunc,
const int maxOutputLength)
const int maxOutputLength,
DiversityFunction diversityFunction)
: opt_(std::move(opt)),
lm_(lm),
eos_(eos),
emittingModelUpdateFunc_(emittingModelUpdateFunc),
maxOutputLength_(maxOutputLength) {}
maxOutputLength_(maxOutputLength),
diversityFunction_(diversityFunction) {}

void decodeStep(const float* emissions, int T, int N) override;

@@ -130,6 +133,10 @@ class LexiconFreeSeq2SeqDecoder : public Decoder {
std::vector<EmittingModelStatePtr> rawPrevStates_;
int maxOutputLength_;

DiversityFunction diversityFunction_;

std::unordered_set<int> uniqueCandidateTokens_;

std::vector<LexiconFreeSeq2SeqDecoderState> candidates_;
std::vector<LexiconFreeSeq2SeqDecoderState*> candidatePtrs_;
double candidatesBestScore_;
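
With the extra constructor parameter, callers pass the diversity callback alongside the existing arguments. A rough usage sketch follows — lm, eosIdx, emittingModelUpdateFunc, maxOutputLength, and hammingDiversity are placeholders assumed to exist in the caller, not values defined by this PR:

// Construct decoder options; values here are arbitrary examples.
LexiconFreeSeq2SeqDecoderOptions opt;
opt.beamSize = 16;      // total beam width
opt.numBeamGroups = 4;  // 4 groups of 16 / 4 = 4 hypotheses each
// ... set beamSizeToken, beamThreshold, lmWeight, eosScore, logAdd as usual ...

LexiconFreeSeq2SeqDecoder decoder(
    std::move(opt),
    lm,                       // LMPtr, assumed available in the caller
    eosIdx,                   // index of the EOS token
    emittingModelUpdateFunc,  // EmittingModelUpdateFunc callback
    maxOutputLength,
    hammingDiversity);        // DiversityFunction, e.g. the sketch above
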
2 changes: 2 additions & 0 deletions flashlight/lib/text/decoder/Utils.h
@@ -116,6 +116,8 @@ using EmittingModelUpdateFunc = std::function<
int& // The current time step being decoded -- 0 --> T
)>;

using DiversityFunction = std::function<double(const std::unordered_set<int>&, const int)>;

/* ===================== Candidate-related operations ===================== */

template <class DecoderState>