diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt index 3410b2e5..700c7776 100644 --- a/xllm/core/common/CMakeLists.txt +++ b/xllm/core/common/CMakeLists.txt @@ -15,6 +15,7 @@ cc_library( rate_limiter.h types.h device_monitor.h + version_singleton.h SRCS etcd_client.cpp global_flags.cpp diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 30b9b4e3..d11019ee 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -389,3 +389,9 @@ DEFINE_string(reasoning_parser, // --- qwen3 reranker config --- DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker."); + +DEFINE_bool(enable_constrained_decoding, + false, + "Whether to enable constrained decoding, which is used to ensure " + "that the output meets specific format or structural requirements " + "through pre-defined rules."); \ No newline at end of file diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 5c79a7c3..409cbdbc 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -202,3 +202,5 @@ DECLARE_bool(enable_qwen3_reranker); DECLARE_string(reasoning_parser); DECLARE_bool(enable_shm); + +DECLARE_bool(enable_constrained_decoding); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index faa1d11a..da8853e3 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -287,4 +287,8 @@ struct MMChatMessage { std::vector content; }; +inline constexpr int REC_TOKEN_SIZE = 3; + +using RecTokenTriple = std::array; + } // namespace xllm diff --git a/xllm/core/common/version_singleton.h b/xllm/core/common/version_singleton.h new file mode 100644 index 00000000..36eb9147 --- /dev/null +++ b/xllm/core/common/version_singleton.h @@ -0,0 +1,109 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. 
#pragma once

#include <algorithm>
#include <iterator>
#include <list>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace xllm {

// Registry holding one lazily created instance of T per version string.
//
// Thread-safety: lookups take a shared lock on instance_map_mutex_; creation,
// eviction and DestroyAllInstances() take an exclusive lock.
//
// Lifetime: the registry owns every instance. A pointer returned by
// GetInstance() becomes invalid once its version is evicted by a later
// GetInstance() call or DestroyAllInstances() runs, so callers should not
// cache it across version changes.
template <typename T>
class VersionSingleton {
 public:
  /**
   * @brief Return the instance registered for `version`, creating it first
   * if it does not exist yet.
   * @param version key identifying the instance
   * @param delete_old_versions when true, creating a new version evicts the
   * oldest ones so that at most `reserved_version_size` remain
   * @param reserved_version_size number of most recently created versions to
   * keep; values below 1 are treated as 1 so the instance created by this
   * call is never evicted before it is returned
   * @param args forwarded to T's constructor; only used when the instance is
   * created by this call
   * @return non-owning pointer to the instance (see the lifetime note on the
   * class)
   */
  template <typename... Args>
  static T* GetInstance(const std::string& version,
                        bool delete_old_versions = true,
                        int reserved_version_size =
                            2,  // default retention of the last two versions
                        Args&&... args) {
    // Fast path: the version already exists, a shared lock is enough.
    {
      std::shared_lock<std::shared_mutex> read_lock(instance_map_mutex_);
      const auto found = instance_map_.find(version);
      if (found != instance_map_.end()) {
        return found->second.get();
      }
    }

    std::unique_lock<std::shared_mutex> write_lock(instance_map_mutex_);

    // Re-check: another thread may have created it between the two locks.
    const auto found = instance_map_.find(version);
    if (found != instance_map_.end()) {
      return found->second.get();
    }

    // `new` is kept (instead of std::make_unique) so that a T which befriends
    // VersionSingleton and hides its constructor can still be created here.
    std::unique_ptr<T> owner(new T(std::forward<Args>(args)...));
    T* const instance = owner.get();
    instance_map_.emplace(version, std::move(owner));
    instance_version_list_.push_front(version);  // newest version first

    if (delete_old_versions) {
      // Always keep at least the instance that is about to be returned;
      // with a retention of 0 it would be destroyed right here.
      const int keep = std::max(reserved_version_size, 1);
      if (instance_version_list_.size() > static_cast<std::size_t>(keep)) {
        const auto first_stale =
            std::next(instance_version_list_.begin(), keep);
        for (auto stale = first_stale; stale != instance_version_list_.end();
             ++stale) {
          instance_map_.erase(*stale);
        }
        instance_version_list_.erase(first_stale,
                                     instance_version_list_.end());
      }
    }

    return instance;
  }

  // Returns the versions currently registered, in unspecified order.
  static std::vector<std::string> GetVersions() {
    std::shared_lock<std::shared_mutex> read_lock(instance_map_mutex_);
    std::vector<std::string> versions;
    versions.reserve(instance_map_.size());
    for (const auto& entry : instance_map_) {
      versions.push_back(entry.first);
    }
    return versions;
  }

  // Destroys every registered instance; all previously returned pointers
  // become invalid.
  static void DestroyAllInstances() {
    std::unique_lock<std::shared_mutex> write_lock(instance_map_mutex_);
    instance_map_.clear();
    instance_version_list_.clear();
  }

  VersionSingleton(const VersionSingleton&) = delete;
  VersionSingleton& operator=(const VersionSingleton&) = delete;

 private:
  VersionSingleton() = default;
  ~VersionSingleton() = default;

  // version -> owned instance
  inline static std::unordered_map<std::string, std::unique_ptr<T>>
      instance_map_;
  // versions ordered from newest to oldest, used to pick eviction victims
  inline static std::list<std::string> instance_version_list_;
  // guards both containers above
  inline static std::shared_mutex instance_map_mutex_;
};

}  // namespace xllm
a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -25,7 +25,10 @@ limitations under the License. #include #include +#include "core/common/version_singleton.h" +#include "core/framework/state_dict/rec_vocab_dict.h" #include "core/framework/tokenizer/fast_tokenizer.h" +#include "core/framework/tokenizer/rec_tokenizer.h" #include "core/framework/tokenizer/sentencepiece_tokenizer.h" #include "core/framework/tokenizer/tiktoken_tokenizer.h" #include "core/framework/tokenizer/tokenizer_factory.h" @@ -50,6 +53,12 @@ HFModelLoader::HFModelLoader(const std::string& model_weights_path) << "Failed to find model weights files in " << model_weights_path; // sort the model weights files by name std::sort(model_weights_files_.begin(), model_weights_files_.end()); + + //@todo: 'false' will be replaced with generative recommendation judgment + if (false) { + CHECK(load_rec_vocab(model_weights_path)) + << "Failed to load rec content from " << model_weights_path; + } } std::unique_ptr HFModelLoader::tokenizer() const { @@ -70,6 +79,28 @@ std::vector>& HFModelLoader::get_state_dicts() { return state_dicts_; } +bool HFModelLoader::load_rec_vocab(const std::string& model_weights_path) { + if (!tokenizer_args_.vocab_file().empty()) { + std::filesystem::path path = model_weights_path; + std::string model_version = path.filename(); + std::string vocab_full_path = + path.append(tokenizer_args_.vocab_file()).string(); + + LOG(INFO) << "model_version:" << model_version + << ", vocab_full_path:" << vocab_full_path; + + CHECK(nullptr != VersionSingleton::GetInstance(model_version)) + << "Failed to get vocab dict instance"; + CHECK(VersionSingleton::GetInstance(model_version) + ->initialize(vocab_full_path)) + << "Failed to initialize vocab dict from " << vocab_full_path; + } else { + LOG(ERROR) << "vocab file is not set"; + } + + return true; +} + bool HFModelLoader::load_args(const std::string& model_weights_path) { if 
(!load_model_args(model_weights_path)) { LOG(ERROR) << "Failed to load model args from " << model_weights_path; diff --git a/xllm/core/framework/hf_model_loader.h b/xllm/core/framework/hf_model_loader.h index eaa8a783..7506db4d 100644 --- a/xllm/core/framework/hf_model_loader.h +++ b/xllm/core/framework/hf_model_loader.h @@ -35,6 +35,7 @@ class HFModelLoader : public ModelLoader { private: bool load_args(const std::string& model_weights_path); + bool load_rec_vocab(const std::string& model_weights_path); bool load_model_args(const std::string& model_weights_path); bool load_quant_args(const std::string& model_weights_path); bool load_tokenizer_args(const std::string& model_weights_path); diff --git a/xllm/core/framework/state_dict/CMakeLists.txt b/xllm/core/framework/state_dict/CMakeLists.txt index 9236338d..e4744b85 100644 --- a/xllm/core/framework/state_dict/CMakeLists.txt +++ b/xllm/core/framework/state_dict/CMakeLists.txt @@ -11,9 +11,11 @@ cc_library( HDRS state_dict.h utils.h + rec_vocab_dict.h SRCS state_dict.cpp utils.cpp + rec_vocab_dict.cpp DEPS rust_safetensors torch diff --git a/xllm/core/framework/state_dict/rec_vocab_dict.cpp b/xllm/core/framework/state_dict/rec_vocab_dict.cpp new file mode 100644 index 00000000..b98aa552 --- /dev/null +++ b/xllm/core/framework/state_dict/rec_vocab_dict.cpp @@ -0,0 +1,138 @@ +#include "rec_vocab_dict.h" + +#include +#include +#include +#include + +#include "common/global_flags.h" +#include "util/timer.h" + +namespace xllm { + +bool RecVocabDict::initialize(const std::string& vocab_file) { + if (initialized_) { + return true; + } + + Timer timer; + + if (vocab_file.empty()) { + LOG(ERROR) << "content data file is empty, file: " << vocab_file; + return false; + } + if (!std::filesystem::exists(vocab_file)) { + LOG(ERROR) << "fail to find content data file: " << vocab_file; + return false; + } + std::ifstream ifs(vocab_file.data(), std::ios::binary | std::ios::ate); + if (!ifs.is_open()) { + LOG(ERROR) << "fail to load 
content data file: " << vocab_file; + return false; + } + + const size_t file_size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + + // each line of content : 1 * int64_t(item id) + REC_TOKEN_SIZE * + // int32_t(token id); + const size_t itemid_size = sizeof(int64_t); + const size_t tokens_size = REC_TOKEN_SIZE * sizeof(int32_t); + const size_t line_size = tokens_size + itemid_size; + const size_t estimated_lines = (file_size + line_size - 1) / line_size; + + // 2 and 4 are only empirical values + item_to_tokens_map_.reserve(estimated_lines); + tokens_to_items_map_.reserve(estimated_lines / 2); + prefix_tokens_to_next_tokens_map_.reserve(estimated_lines / 4); + + int64_t item_id = 0; + RecTokenTriple tokens; + + while (ifs.read(reinterpret_cast(&item_id), itemid_size) && + ifs.read(reinterpret_cast(tokens.data()), tokens_size)) { + if (FLAGS_enable_constrained_decoding) { + for (int i = 0; i < tokens.size(); i++) { + std::vector prefix_tokens; + + for (int j = 0; j < i; j++) { + prefix_tokens.emplace_back(tokens[j]); + } + + prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]); + } + } + + item_to_tokens_map_[item_id] = tokens; + + tokens_to_items_map_[tokens].emplace_back(item_id); + } + + if (ifs.gcount() != 0 && ifs.gcount() != line_size) { + LOG(ERROR) << "possibly containing incomplete lines : " << vocab_file; + item_to_tokens_map_.clear(); + tokens_to_items_map_.clear(); + prefix_tokens_to_next_tokens_map_.clear(); + return false; + } + + initialized_ = true; + LOG(INFO) << "total line size:" << estimated_lines + << ",parse tokens to item id map size: " + << tokens_to_items_map_.size() + << ", parse item to tokens map size:" << item_to_tokens_map_.size() + << ", parse prefix tokens to next tokens map size:" + << prefix_tokens_to_next_tokens_map_.size() + << ", cost: " << timer.elapsed_seconds() << " seconds"; + + return true; +} + +bool RecVocabDict::get_items_by_tokens(const RecTokenTriple& rec_token_triple, + std::vector* item_ids) const { + 
CHECK_EQ(initialized_, true); + CHECK_NE(item_ids, nullptr); + + auto iter = tokens_to_items_map_.find(rec_token_triple); + if (iter == tokens_to_items_map_.end()) { + return false; + } + + std::copy( + iter->second.begin(), iter->second.end(), std::back_inserter(*item_ids)); + + return true; +} + +bool RecVocabDict::get_tokens_by_item(int64_t item_id, + std::vector* token_ids) const { + CHECK_EQ(initialized_, true); + CHECK_NE(token_ids, nullptr); + + auto iter = item_to_tokens_map_.find(item_id); + if (iter == item_to_tokens_map_.end()) { + return false; + } + + std::copy( + iter->second.begin(), iter->second.end(), std::back_inserter(*token_ids)); + + return true; +} + +const std::set& RecVocabDict::get_next_tokens_by_prefix_tokens( + const Slice& prefix_token_ids) const { + CHECK_EQ(initialized_, true); + CHECK_LT(prefix_token_ids.size(), REC_TOKEN_SIZE); + + std::vector prefix_tokens_ids_vec = prefix_token_ids; + auto iter = prefix_tokens_to_next_tokens_map_.find(prefix_tokens_ids_vec); + if (iter == prefix_tokens_to_next_tokens_map_.end()) { + static std::set empty_set; + return empty_set; + } + + return iter->second; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/state_dict/rec_vocab_dict.h b/xllm/core/framework/state_dict/rec_vocab_dict.h new file mode 100644 index 00000000..62d5a5ef --- /dev/null +++ b/xllm/core/framework/state_dict/rec_vocab_dict.h @@ -0,0 +1,110 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/macros.h" +#include "common/types.h" +#include "util/slice.h" + +namespace xllm { +// a vocab dictionary in generative recommendation scenarios, used for mapping +// token IDs and item IDs, currently updated with the model version, and +// real-time updates are not supported. +class RecVocabDict final { + public: + RecVocabDict() = default; + + ~RecVocabDict() { + initialized_ = false; + item_to_tokens_map_.clear(); + tokens_to_items_map_.clear(); + prefix_tokens_to_next_tokens_map_.clear(); + } + + /** + * @brief initialize instance, parse vocab file + * @param vocab_file vocab file, need full path + * @return true represents successful initialization, false represents failed + * initialization + */ + bool initialize(const std::string& vocab_file); + + /** + * @brief get the corresponding item ID list through a token ID triplet + * @param token_ids, a token ID triplet, so token_ids size must be three + * @param item_ids, output mapping item id list + * @return true represents successful gain, false represents failed gain + */ + bool get_items_by_tokens(const RecTokenTriple& rec_token_triple, + std::vector* item_ids) const; + + /** + * @brief get the corresponding token ID triplet through a item id + * @param item_ids, input item id + * @param token_ids, output mapping token id triplet, so token_ids size will + * be three + * @return true represents successful gain, false represents 
failed gain + */ + bool get_tokens_by_item(int64_t item_id, + std::vector* token_ids) const; + + /** + * @brief get all next token id list through the prefix token id list, for + * example, in the vocab file, there are these token id triplets, 1-2-3, + * 1-2-4, 7-8-9, if prefix the token id is [1], then the next token id list + * is [2], if the prefix token id is [1,2], then the next token id list is + * [3,4] + * @param prefix_token_ids, prefix token id list, the size must be less then + * three + * @attention if prefix_token_ids size is zero, will return all first token of + * the token triplets + * @return next token id list + */ + const std::set& get_next_tokens_by_prefix_tokens( + const Slice& prefix_token_ids) const; + + private: + // check if initialization has been successful + bool initialized_ = false; + + // convert token to item map, key: token id triplet, value: item id list, + // there is a token id triplet corresponding to multiple item IDs, and + // boost::hash will generate ordered triplet hash value + std::unordered_map, + boost::hash> + tokens_to_items_map_; + + // convert item to tokens map, key: item id, value: token triplet, there is a + // item id corresponding to a token id triplet + std::unordered_map item_to_tokens_map_; + + // convert prifix tokens to next tokens map, key: prefix token id list, value: + // next token id list + std::unordered_map, + std::set, + boost::hash>> + prefix_tokens_to_next_tokens_map_; +}; +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/tokenizer/CMakeLists.txt b/xllm/core/framework/tokenizer/CMakeLists.txt index 3ee527e8..93403693 100644 --- a/xllm/core/framework/tokenizer/CMakeLists.txt +++ b/xllm/core/framework/tokenizer/CMakeLists.txt @@ -14,12 +14,14 @@ cc_library( sentencepiece_tokenizer.h fast_tokenizer.h tokenizer_proxy.h + rec_tokenizer.h SRCS tokenizer_factory.cpp tiktoken_tokenizer.cpp sentencepiece_tokenizer.cpp fast_tokenizer.cpp tokenizer_proxy.cpp + 
rec_tokenizer.cpp DEPS :common :sentencepiece diff --git a/xllm/core/framework/tokenizer/rec_tokenizer.cpp b/xllm/core/framework/tokenizer/rec_tokenizer.cpp new file mode 100644 index 00000000..64375be8 --- /dev/null +++ b/xllm/core/framework/tokenizer/rec_tokenizer.cpp @@ -0,0 +1,52 @@ +#include "rec_tokenizer.h" + +#include + +#include "common/version_singleton.h" +#include "state_dict/rec_vocab_dict.h" + +namespace xllm { +RecTokenizer::RecTokenizer(const std::string_view& dir_path, + const TokenizerArgs& args) { + args_ = args; + dir_path_ = dir_path; + model_version_ = std::filesystem::path(dir_path).filename(); +} + +bool RecTokenizer::encode(int64_t item_id, + std::vector* token_ids) const { + if (!VersionSingleton::GetInstance(model_version_) + ->get_tokens_by_item(item_id, token_ids)) { + return false; + } + + return true; +} + +bool RecTokenizer::decode(const Slice& token_ids, + bool skip_special_tokens, + std::vector* item_ids) const { + CHECK_EQ(token_ids.size(), REC_TOKEN_SIZE); + + RecTokenTriple rec_token_triple; + std::copy(token_ids.begin(), token_ids.end(), rec_token_triple.begin()); + + if (!VersionSingleton::GetInstance(model_version_) + ->get_items_by_tokens(rec_token_triple, item_ids)) { + return false; + } + + return true; +} + +size_t RecTokenizer::vocab_size() const { + // currently, there is no voice size set in the tokenizer configuration. The + // voice size can be obtained from the model args + return 0; +} + +std::unique_ptr RecTokenizer::clone() const { + return std::make_unique(dir_path_, args_); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/tokenizer/rec_tokenizer.h b/xllm/core/framework/tokenizer/rec_tokenizer.h new file mode 100644 index 00000000..41b03f0b --- /dev/null +++ b/xllm/core/framework/tokenizer/rec_tokenizer.h @@ -0,0 +1,56 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "tokenizer.h" +#include "tokenizer_args.h" +#include "util/slice.h" + +namespace xllm { + +class RecTokenizer : public Tokenizer { + public: + RecTokenizer(const std::string_view& dir_path, const TokenizerArgs& args); + + virtual ~RecTokenizer() = default; + + bool encode(int64_t item_id, std::vector* token_ids) const override; + + bool decode(const Slice& token_ids, + bool skip_special_tokens, + std::vector* item_ids) const override; + + size_t vocab_size() const override; + + std::unique_ptr clone() const override; + + private: + TokenizerArgs args_; + + std::string dir_path_; + + std::string model_version_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/tokenizer/tokenizer.h b/xllm/core/framework/tokenizer/tokenizer.h index 995f04f5..0b4b86b3 100644 --- a/xllm/core/framework/tokenizer/tokenizer.h +++ b/xllm/core/framework/tokenizer/tokenizer.h @@ -30,7 +30,9 @@ class Tokenizer { virtual ~Tokenizer() = default; virtual bool encode(const std::string_view& text, - std::vector* ids) const = 0; + std::vector* ids) const { + return false; + } virtual bool batch_encode(const std::vector& texts, std::vector>* ids) const { @@ -45,16 +47,31 @@ class Tokenizer { } virtual std::string decode(const Slice& ids, - bool skip_special_tokens) const = 0; + bool 
skip_special_tokens) const { + return ""; + } + + // only for generative recommendation + virtual bool encode(int64_t item_id, std::vector* token_ids) const { + return false; + } + // only for generative recommendation + virtual bool decode(const Slice& token_ids, + bool skip_special_tokens, + std::vector* item_ids) const { + return false; + } virtual std::optional token_to_id( - const std::string_view& token) const = 0; + const std::string_view& token) const { + return std::nullopt; + } - virtual std::string id_to_token(int32_t id) const = 0; + virtual std::string id_to_token(int32_t id) const { return ""; } - virtual size_t vocab_size() const = 0; + virtual size_t vocab_size() const { return 0; } - virtual std::unique_ptr clone() const = 0; + virtual std::unique_ptr clone() const { return nullptr; } }; } // namespace xllm diff --git a/xllm/core/framework/tokenizer/tokenizer_factory.cpp b/xllm/core/framework/tokenizer/tokenizer_factory.cpp index 0f56291b..806fae48 100644 --- a/xllm/core/framework/tokenizer/tokenizer_factory.cpp +++ b/xllm/core/framework/tokenizer/tokenizer_factory.cpp @@ -36,12 +36,18 @@ std::unique_ptr TokenizerFactory::create_tokenizer( LOG(INFO) << "Create Tiktoken tokenizer."; tokenizer = std::make_unique(model_weights_path, tokenizer_args); + } else if (tokenizer_args.tokenizer_type() == "rec") { + // 3. create rec tokenizer + LOG(INFO) << "Create rec tokenizer."; + tokenizer = + std::make_unique(model_weights_path, tokenizer_args); } else { - // 3. create sentencepiece tokenizer + // 4. 
create sentencepiece tokenizer LOG(INFO) << "Create SentencePiece tokenizer."; tokenizer = std::make_unique(model_weights_path, tokenizer_args); } + if (proxy) { return std::make_unique(std::move(tokenizer)); } diff --git a/xllm/core/framework/tokenizer/tokenizer_factory.h b/xllm/core/framework/tokenizer/tokenizer_factory.h index f7d9886c..8c2b2c73 100644 --- a/xllm/core/framework/tokenizer/tokenizer_factory.h +++ b/xllm/core/framework/tokenizer/tokenizer_factory.h @@ -16,6 +16,7 @@ limitations under the License. #pragma once #include "fast_tokenizer.h" +#include "rec_tokenizer.h" #include "sentencepiece_tokenizer.h" #include "tiktoken_tokenizer.h" #include "tokenizer_args.h"