diff --git a/src/Lm/Makefile b/src/Lm/Makefile index 861ead682..132c294e8 100644 --- a/src/Lm/Makefile +++ b/src/Lm/Makefile @@ -55,6 +55,10 @@ CXXFLAGS += $(TF_CXXFLAGS) LDFLAGS += $(TF_LDFLAGS) endif +ifdef MODULE_ONNX +LIBSPRINTLM_O += $(OBJDIR)/OnnxStatelessLanguageModel.o +endif + CHECK_O = $(OBJDIR)/check.o \ ../Flf/libSprintFlf.$(a) \ ../Flf/FlfCore/libSprintFlfCore.$(a) \ diff --git a/src/Lm/Module.cc b/src/Lm/Module.cc index 1862c9ad4..c817c1c19 100644 --- a/src/Lm/Module.cc +++ b/src/Lm/Module.cc @@ -38,6 +38,10 @@ #include "ReducedPrecisionCompressedVectorFactory.hh" #endif +#ifdef MODULE_ONNX +#include "OnnxStatelessLanguageModel.hh" +#endif + #include "SimpleHistoryLm.hh" using namespace Lm; @@ -51,7 +55,8 @@ enum LanguageModelType { lmTypeCombine, lmTypeTFRNN, lmTypeCheatingSegment, - lmTypeSimpleHistory + lmTypeSimpleHistory, + lmTypeOnnxStateless }; } @@ -64,6 +69,7 @@ const Core::Choice Module_::lmTypeChoice( "tfrnn", lmTypeTFRNN, "cheating-segment", lmTypeCheatingSegment, "simple-history", lmTypeSimpleHistory, + "onnx-stateless", lmTypeOnnxStateless, Core::Choice::endMark()); const Core::ParameterChoice Module_::lmTypeParam( @@ -91,6 +97,9 @@ Core::Ref Module_::createLanguageModel( case lmTypeTFRNN: result = Core::ref(new TFRecurrentLanguageModel(c, l)); break; #endif case lmTypeSimpleHistory: result = Core::ref(new SimpleHistoryLm(c, l)); break; +#ifdef MODULE_ONNX + case lmTypeOnnxStateless: result = Core::ref(new OnnxStatelessLm(c, l)); break; +#endif default: Core::Application::us()->criticalError("unknwon language model type: %d", lmTypeParam(c)); } diff --git a/src/Lm/OnnxStatelessLanguageModel.cc b/src/Lm/OnnxStatelessLanguageModel.cc new file mode 100644 index 000000000..2d4b69d2b --- /dev/null +++ b/src/Lm/OnnxStatelessLanguageModel.cc @@ -0,0 +1,177 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. 
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OnnxStatelessLanguageModel.hh"
+
+namespace Lm {
+
+// Inputs and outputs the ONNX model has to provide: INT32 tensors "tokens" (two dimensions) and "lengths" (one
+// dimension) as inputs and a FLOAT tensor "scores" (two dimensions) as output. NOTE(review): the meaning of the
+// boolean flag and of the -1/-2 dimension codes is defined by Onnx::IOSpecification, which is not visible here.
+static const std::vector<Onnx::IOSpecification> ioSpec = {
+        Onnx::IOSpecification{"tokens", Onnx::IODirection::INPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::INT32}, {{-1, -1}}},
+        Onnx::IOSpecification{"lengths", Onnx::IODirection::INPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::INT32}, {{-1}}},
+        Onnx::IOSpecification{"scores", Onnx::IODirection::OUTPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::FLOAT}, {{-1, -2}}}};
+
+// Upper bound on the number of histories that makeBatch() puts into one forward pass.
+const Core::ParameterInt paramMaxBatchSize(
+        "max-batch-size",
+        "Maximum number of histories forwarded in one go",
+        64, 1);
+
+// Builds the ONNX model from the "onnx-model" sub-configuration and looks up the ONNX names that the model
+// maps to the "tokens", "lengths" and "scores" entries of `ioSpec`.
+OnnxStatelessLm::OnnxStatelessLm(const Core::Configuration& c, Bliss::LexiconRef l)
+        : Core::Component(c),
+          Precursor(c, l),
+          onnxModel_(select("onnx-model"), ioSpec),
+          inputTokensName_(onnxModel_.mapping.getOnnxName("tokens")),
+          inputLengthsName_(onnxModel_.mapping.getOnnxName("lengths")),
+          scoresName_(onnxModel_.mapping.getOnnxName("scores")),
+          maxBatchSize_(paramMaxBatchSize(config)),
+          batchQueue_(),
+          batch_(),
+          startHistory_() {
+}
+
+// Initializes the vocabulary and caches the start history, so later startHistory() calls return the cached object.
+void OnnxStatelessLm::load() {
+    loadVocabulary();
+    startHistory_ = startHistory();
+}
+
+// Returns the history consisting of a single sentence-begin token. Once `startHistory_` is valid it is returned
+// directly; otherwise a new history is created and appended to `batchQueue_` for batched scoring.
+History OnnxStatelessLm::startHistory() const {
+    if (startHistory_.isValid()) {
+        return startHistory_;
+    }
+
+    auto            sentBeginId = lexicon_mapping_.at(sentenceBeginToken()->id());
+    TokenIdSequence tokenSequence(1ul, sentBeginId);
+
+    // NOTE(review): the cast target and the template argument of `get` were lost in extraction and are reconstructed - confirm.
+    auto historyManager = dynamic_cast<NNHistoryManager*>(historyManager_);
+    verify(historyManager != nullptr);  // dynamic_cast yields nullptr if the manager has a different type
+    auto hist = history(historyManager->get<HistoryDescriptor>(tokenSequence));
+    batchQueue_.push_back(hist);
+    return hist;
+}
+
+// Returns the history obtained by appending the mapped id of `nextToken` to the token sequence of `hist`.
+// The new history is appended to `batchQueue_` for batched scoring.
+History OnnxStatelessLm::extendedHistory(History const& hist, Token nextToken) const {
+    auto tokenId = lexicon_mapping_.at(nextToken->id());
+
+    // NOTE(review): cast targets and the template argument of `get` are reconstructed as in startHistory() - confirm.
+    auto historyManager = dynamic_cast<NNHistoryManager*>(historyManager_);
+    verify(historyManager != nullptr);
+    // static_cast (not reinterpret_cast), consistent with every other handle cast in this file
+    auto descriptor = static_cast<HistoryDescriptor const*>(hist.handle());
+
+    TokenIdSequence newTokens(*descriptor->history);
+    newTokens.push_back(tokenId);
+
+    auto extHist = history(historyManager->get<HistoryDescriptor>(newTokens));
+    batchQueue_.push_back(extHist);
+    return extHist;
+}
+
+// Returns the score stored for the mapped id of `nextToken` in the descriptor of `hist`. If the descriptor has
+// no scores yet, a batch containing `hist` plus queued histories is forwarded through the model first.
+Score OnnxStatelessLm::score(History const& hist, Token nextToken) const {
+    size_t tokenId    = lexicon_mapping_.at(nextToken->id());
+    auto   descriptor = static_cast<HistoryDescriptor const*>(hist.handle());
+
+    if (descriptor->scores.empty()) {
+        makeBatch(hist);
+        scoreBatch();
+        batch_.clear();
+    }
+    // Subsumes the non-empty check and guards the index below against out-of-range access
+    verify(tokenId < descriptor->scores.size());
+    return descriptor->scores[tokenId];
+}
+
+// Fills `batch_` with `hist` followed by histories taken from the front of `batchQueue_` until `maxBatchSize_`
+// entries are reached or the queue is empty. Queued histories that already have scores, or whose token sequence
+// object is already part of the batch, are removed from the queue without being added.
+void OnnxStatelessLm::makeBatch(History const& hist) const {
+    // Token sequences are compared by the address of the sequence object, not by content
+    std::unordered_set<TokenIdSequence const*> seenHistories;
+
+    batch_.push_back(hist);
+    seenHistories.insert(static_cast<HistoryDescriptor const*>(hist.handle())->history.get());
+
+    while (batch_.size() < maxBatchSize_ and not batchQueue_.empty()) {
+        auto        queuedHistory    = batchQueue_.front();
+        auto const* queuedDescriptor = static_cast<HistoryDescriptor const*>(queuedHistory.handle());
+        auto const* queuedTokenSeq   = queuedDescriptor->history.get();
+        batchQueue_.pop_front();
+
+        if (seenHistories.find(queuedTokenSeq) == seenHistories.end() and queuedDescriptor->scores.empty()) {
+            batch_.push_back(queuedHistory);
+            seenHistories.insert(queuedTokenSeq);
+        }
+    }
+}
+
+// Forwards all histories in `batch_` through the ONNX model in a single session run and stores the output
+// belonging to each history in the `scores` of its descriptor. Does nothing if `batch_` is empty.
+void OnnxStatelessLm::scoreBatch() const {
+    if (batch_.empty()) {
+        return;
+    }
+    std::vector<HistoryDescriptor*> descriptors;
+    descriptors.reserve(batch_.size());
+    size_t maxLength = 0ul;
+    for (auto const& hist : batch_) {
+        // The descriptors' scores are written at the end of this function, hence the const_cast
+        descriptors.push_back(const_cast<HistoryDescriptor*>(static_cast<HistoryDescriptor const*>(hist.handle())));
+        maxLength = std::max(maxLength, descriptors.back()->history->size());
+    }
+
+    // NOTE(review): element type reconstructed from the INT32 entries of `ioSpec` (lost in extraction) - confirm.
+    Math::FastMatrix<s32> tokenMat(maxLength, batch_.size());
+    Math::FastVector<s32> lengthVec(batch_.size());
+
+    for (u32 b = 0u; b < descriptors.size(); ++b) {
+        auto const& tokens = *descriptors[b]->history;
+        lengthVec[b]       = tokens.size();
+        for (u32 n = 0u; n < tokens.size(); ++n) {
+            tokenMat.at(n, b) = tokens.at(n);
+        }
+        // zero padding
+        for (u32 n = tokens.size(); n < maxLength; ++n) {
+            tokenMat.at(n, b) = 0;
+        }
+    }
+
+    std::vector<std::pair<std::string, Onnx::Value>> sessionInputs;
+    sessionInputs.emplace_back(inputTokensName_, Onnx::Value::create(tokenMat, true));
+    sessionInputs.emplace_back(inputLengthsName_, Onnx::Value::create(lengthVec));
+
+    std::vector<Onnx::Value> sessionOutputs;
+    onnxModel_.session.run(std::move(sessionInputs), {scoresName_}, sessionOutputs);
+    verify(not sessionOutputs.empty());  // front() on an empty vector would be undefined behaviour
+
+    Onnx::Value scoreOutput(std::move(sessionOutputs.front()));  // Only one session output
+
+    for (u32 b = 0u; b < descriptors.size(); ++b) {
+        scoreOutput.get(b, descriptors[b]->scores);
+    }
+}
+
+}  // namespace Lm
diff --git a/src/Lm/OnnxStatelessLanguageModel.hh b/src/Lm/OnnxStatelessLanguageModel.hh
new file mode 100644
index 000000000..0afaf761b
--- /dev/null
+++ b/src/Lm/OnnxStatelessLanguageModel.hh
@@ -0,0 +1,86 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _LM_ONNX_STATELESS_LM_HH
+#define _LM_ONNX_STATELESS_LM_HH
+
+#include <deque>  // NOTE(review): header name was lost in extraction; std::deque is used below - confirm
+
+#include <Onnx/Model.hh>  // NOTE(review): header name was lost in extraction; it must provide Onnx::Model - confirm
+
+#include "AbstractNNLanguageModel.hh"
+
+namespace Lm {
+
+struct NNCacheWithScores : public Lm::NNCacheWithStats {
+    virtual ~NNCacheWithScores() = default;
+    // Empty until OnnxStatelessLm::scoreBatch fills it; read by mapped token id. NOTE(review): element type reconstructed - confirm
+    std::vector<Score> scores;
+};
+
+/*
+ * Simple ONNX Language Model without any state caching. The entire token history is fed into the ONNX model
+ * for each score request. This trades efficiency for simplicity and flexibility. Thus, it is mostly useful
+ * for prototyping and models with a relatively small search space.
+ */
+class OnnxStatelessLm : public AbstractNNLanguageModel {
+    typedef AbstractNNLanguageModel Precursor;
+    typedef NNCacheWithScores       HistoryDescriptor;
+
+public:
+    OnnxStatelessLm(const Core::Configuration& c, Bliss::LexiconRef l);
+    ~OnnxStatelessLm() = default;
+
+    // History consisting of a single sentence-begin token; the cached `startHistory_` is returned once it is valid
+    History startHistory() const;
+
+    // History of `hist` with the mapped token id of `nextToken` appended; the result is queued in `batchQueue_`
+    History extendedHistory(const History& hist, Token nextToken) const;
+
+    // Score of `nextToken` given `hist`; if `hist` has no scores yet it is forwarded through the ONNX model in a batch
+    Score score(const History& hist, Token nextToken) const;
+
+private:
+    mutable Onnx::Model onnxModel_;  // mutable because the const method scoreBatch() runs its session
+
+    std::string inputTokensName_;   // ONNX name mapped to the "tokens" input
+    std::string inputLengthsName_;  // ONNX name mapped to the "lengths" input
+    std::string scoresName_;        // ONNX name mapped to the "scores" output
+
+    size_t maxBatchSize_;  // Value of the "max-batch-size" parameter; upper bound for the size of `batch_`
+
+    // When new histories are created through `extendedHistory`, they are put into this queue for batched forwarding
+    // because it is expected that we need to compute scores for them in the future anyway.
+    mutable std::deque<History> batchQueue_;
+
+    // Batch of histories which are forwarded at once; filled by `makeBatch` and cleared by `score` after forwarding
+    mutable std::vector<History> batch_;
+
+    // Cached history object containing only a single sentence-begin token
+    History startHistory_;
+
+    // Initialize vocabulary and start history
+    void load();
+
+    // Creates a batch of histories that contains `hist` plus additional histories fetched from the `batchQueue_`
+    void makeBatch(History const& hist) const;
+
+    // Score all histories inside `batch_` and store the results in their history descriptors
+    void scoreBatch() const;
+};
+
+}  // namespace Lm
+
+#endif  // _LM_ONNX_STATELESS_LM_HH