diff --git a/docs/design/local-embedding-llama-cpp-design.md b/docs/design/local-embedding-llama-cpp-design.md
new file mode 100644
index 000000000..c097dab4d
--- /dev/null
+++ b/docs/design/local-embedding-llama-cpp-design.md
@@ -0,0 +1,449 @@
+# OpenViking Local Embedding (llama-cpp) Design Document
+
+Date: 2026-04-11
+Status: Approved for implementation
+
+## Goals
+
+Add built-in local dense embedding to OpenViking, with the following product behavior:
+
+- When the user does not explicitly configure `embedding`, OpenViking defaults to the local embedding backend
+- The default local model is `bge-small-zh-v1.5-f16`
+- Local inference loads GGUF models through `llama-cpp-python`
+- The local inference dependency is not part of the main requirements; it ships as an optional extra to reduce installation risk
+
+The end result: keep the product goal of "local embedding by default" while making the design implementable and maintainable.
+
+## Scope
+
+In scope:
+
+- A new `embedding.backend = "local"` backend
+- Automatically generating an implicit local embedding configuration when the user provides none
+- A dense embedder built on `llama-cpp-python`
+- Default model `bge-small-zh-v1.5-f16`
+- Model path resolution, download, and cache directory management
+- Query/document dual-path encoding semantics
+- Collection metadata validation and rebuild rules
+- Startup-time validation and error messages
+- Test plan and benchmark scaffolding
+
+Out of scope:
+
+- Local sparse embedding
+- Local hybrid embedding
+- Silent fallback to a remote provider after a local failure
+- Automatic dependency installation at runtime
+- Replacing the existing remote providers
+
+## Decision Summary
+
+OpenViking adopts the following combination of policies:
+
+1. Default product behavior: if the user configures no embedding, OpenViking implicitly selects the local embedding backend.
+2. Dependency distribution: `llama-cpp-python` is not a main dependency; it is distributed through an optional extra such as `openviking[local-embed]`.
+3. Default local model: `bge-small-zh-v1.5-f16`.
+4. Failure policy: if the system defaulted to local embedding but the local dependency or model is unavailable, fail immediately with clear recovery guidance; never silently fall back to a remote model.
+
+This is not identical to QMD's approach. QMD is a Node CLI product and can ship `node-llama-cpp` as a main dependency; OpenViking is a Python SDK and server component, where letting a native dependency block installation of the main package would be far more costly.
+
+## Rationale
+
+The research findings and the constraints of the current codebase point in the same direction:
+
+- QMD shows that "local embedding by default" is a viable product direction.
+- OpenClaw / ArkClaw show that GGUF-based local memory search has clear user value.
+- OpenViking's current architecture surfaces embedding initialization failures at startup rather than deferring them to query time.
+- In the Python ecosystem, a failing native dependency usually costs more than in the npm/Node ecosystem QMD lives in.
+
+This design therefore preserves the product goal while keeping the blast radius of native dependency failures as small as possible.
+
+## User-Visible Behavior
+
+### Default behavior
+
+If the configuration contains no `embedding` section:
+
+- OpenViking automatically generates an implicit local dense embedding configuration
+- backend is set to `local`
+- model is set to `bge-small-zh-v1.5-f16`
+- dimension is set to that model's dimension
+
+From the user's point of view, local embedding is simply the default.
+
+### Explicit behavior
+
+If the user explicitly configures `embedding`, the explicit configuration always wins, including:
+
+- an explicit `backend: "local"`
+- an explicit remote backend such as `openai`, `volcengine`, or `vikingdb`
+- an explicit `model_path`
+- an explicit `cache_dir`
+
+There must be no implicit rewriting that overrides user configuration.
+
+### Installation experience
+
+Base install:
+
+```bash
+pip install openviking
+```
+
+Enable local embedding:
+
+```bash
+pip install "openviking[local-embed]"
+```
+
+If the user relies on the default local behavior but has not installed the local extra, the system must emit an actionable error at startup that includes at least:
+
+- that local embedding is currently enabled by default
+- that `llama-cpp-python` is missing
+- the install command that enables local embedding
+- how to explicitly configure a remote provider instead
+
+## Configuration Design
+
+Add `local` as a valid backend in `EmbeddingModelConfig`.
+
+Fields supported by the local dense backend:
+
+- `backend`: `"local"`
+- `model`: logical model name, default `bge-small-zh-v1.5-f16`
+- `model_path`: optional, explicit path to a GGUF file
+- `cache_dir`: optional, cache root directory, default `~/.cache/openviking/models/`
+- `dimension`: optional, normally derived from the built-in model registry
+- `batch_size`: reserved for future batch embedding
+
+Recommended configuration:
+
+```json
+{
+  "embedding": {
+    "dense": {
+      "backend": "local",
+      "model": "bge-small-zh-v1.5-f16",
+      "cache_dir": "~/.cache/openviking/models"
+    }
+  }
+}
+```
+
+Explicit model path:
+
+```json
+{
+  "embedding": {
+    "dense": {
+      "backend": "local",
+      "model": "bge-small-zh-v1.5-f16",
+      "model_path": "/data/models/bge-small-zh-v1.5-f16.gguf"
+    }
+  }
+}
+```
+
+## Architecture
+
+### New components
+
+Add a local dense embedder implementation, for example:
+
+- `openviking/models/embedder/local_embedders.py`
+- `LocalDenseEmbedder`
+
+Its responsibilities include:
+
+- verifying that `llama-cpp-python` is available
+- resolving logical model names to GGUF model specs
+- resolving or downloading the model file
+- initializing the llama embedding context
+- providing query/document dual-path embedding methods
+- reporting the model dimension
+- releasing local resources in `close()`
+
+### Factory and configuration changes
+
+Required changes:
+
+- `EmbeddingModelConfig.validate_config()` to accept `backend == "local"`
+- `EmbeddingConfig._create_embedder()` to support `("local", "dense")`
+- the default configuration logic, so that a missing embedding config automatically becomes local dense
+
+### Model registry
+
+Add a built-in registry of local models. The first version can be a simple mapping keyed by logical model name, carrying:
+
+- logical model name
+- GGUF download URL or HuggingFace location info
+- expected dimension
+- recommended prompt rule
+- optional target file name
+
+The first built-in model is:
+
+- `bge-small-zh-v1.5-f16`
+
+## Query / Document Dual-Path Encoding
+
+This is not an optional optimization; it is a design point this proposal must handle.
+
+BGE/E5-style models are retrieval-oriented and normally need to distinguish:
+
+- query: the search terms or question the user types
+- document: the text chunks that are stored and retrieved
+
+OpenViking currently only exposes `embed(text)`, which cannot express this semantic difference.
+
+The design adds explicit interfaces:
+
+- `embed_query(text: str) -> EmbedResult`
+- `embed_document(text: str) -> EmbedResult`
+
+For compatibility with existing code, `embed(text)` may remain as a thin wrapper, but it must carry role semantics internally. New retrieval code should call `embed_query()`, and new ingestion code should call `embed_document()`.
+
+The query/document formatting rules must be encapsulated inside the local embedder, not assembled piecemeal in business-layer code.
+
+## Model Resolution and Download Flow
+
+### Resolution order
+
+1. If `model_path` is configured, use that path directly.
+2. Otherwise resolve `model` through the built-in local model registry.
+3. If the target file does not exist, download it into `cache_dir`.
+4. Initialize `llama-cpp-python` with the resolved GGUF file.
+
+### Cache directory
+
+Default directory:
+
+- `~/.cache/openviking/models/`
+
+Behavioral requirements:
+
+- create the directory automatically if it does not exist
+- store downloaded GGUF files here
+- do not download again if the target file already exists
+
+### Download policy
+
+The first version must provide:
+
+- readable error output
+- stable, predictable file naming
+- manual retry after a failure
+
+The first version does not require:
+
+- resumable downloads
+- automatic switching between mirror sources
+- a background asynchronous downloader
+
+## Startup Timing and Failure Behavior
+
+OpenViking currently initializes the embedder when the client starts, and the local backend keeps that behavior.
+
+As a result, all of the following problems surface directly at startup:
+
+- the local extra is not installed
+- `llama-cpp-python` fails to import
+- the model file is missing and the download fails
+- the GGUF file exists but fails to load
+- the current collection metadata does not match the configured model
+
+### Error handling rules
+
+Missing local dependency:
+
+- raise an explicit configuration/runtime error
+- the error message must include the missing package name, the install command, and how to switch to a remote provider
+
+Model download failure:
+
+- raise an error that includes the logical model name, the resolved URL, the cache directory, and the original exception
+
+Model load failure:
+
+- raise an error indicating an incompatible GGUF, a corrupted file, or an unsupported runtime environment
+
+Metadata mismatch:
+
+- raise an error stating that the current embedding settings are incompatible with the existing index and a rebuild is required
+
+No silent fallback:
+
+- when local initialization fails, never quietly switch to `openai`, `volcengine`, or `vikingdb`
+
+## Collection Metadata and Rebuild Rules
+
+The current system only validates vector dimensions at write time, which is no longer sufficient once a local model becomes the default.
+
+At minimum, persist the following metadata (an example of the persisted payload follows this section):
+
+- `embedding_backend`
+- `embedding_model`
+- `embedding_dimension`
+- `embedding_model_identity`
+
+`embedding_model_identity` distinguishes cases where the model name looks the same but the actual model file differs. It can be derived from:
+
+- the resolved model path
+- a hash of the model path
+- a file hash (if the cost is acceptable)
+
+### Rebuild triggers
+
+If any of the following changes:
+
+- backend
+- model
+- dimension
+- model identity
+
+the existing vectors must be treated as incompatible. The system must either:
+
+- fail at startup and prompt for a rebuild, or
+- run the rebuild flow when the user explicitly triggers it
+
+The first version should use explicit rebuilds rather than implicit migration.
+## Data Flow Changes
+
+### Ingestion flow
+
+Current flow:
+
+- semantic processing produces text
+- the queue consumer calls `embed()`
+
+New flow:
+
+- the queue consumer calls `embed_document()`
+- the local embedder applies the document-side rules automatically
+- vectors are written together with collection metadata that matches the current model
+
+### Retrieval flow
+
+Current flow:
+
+- the retriever calls `embed()`
+
+New flow:
+
+- the retriever calls `embed_query()`
+- the local embedder applies the query-side rules automatically
+- retrieval uses vectors produced by the same scheme as the stored documents
+
+## Batch Embedding Strategy
+
+The first version may focus on getting single-item processing right, but the design must lay out a clear path to batch optimization.
+
+Phase one:
+
+- implement `embed_batch()` in `LocalDenseEmbedder`
+- the queue layer may keep processing items one at a time for now
+
+Phase two:
+
+- aggregate messages on the queue side and encode multiple pending texts in one call
+
+Without batching, local CPU index-building throughput will very likely fall well below the model's benchmark numbers.
+
+## Development Order
+
+1. Add configuration validation and factory registration for the `local` backend.
+2. Add the logic that generates a default local dense configuration when the embedding config is missing.
+3. Implement `LocalDenseEmbedder` on top of `llama-cpp-python`.
+4. Add the built-in local model registry and wire in `bge-small-zh-v1.5-f16`.
+5. Add model path resolution, cache directory handling, and automatic download.
+6. Add the `embed_query()` / `embed_document()` dual-path interfaces.
+7. Persist collection embedding metadata and add consistency checks.
+8. Add the rebuild-required error flow.
+9. Add `embed_batch()` and the benchmark infrastructure.
+10. Update user documentation, example configurations, and installation instructions.
+
+## Test Plan
+
+### Configuration tests
+
+- a missing `embedding` section produces the implicit local dense configuration
+- an explicit remote configuration does not trigger the default local logic
+- `model_path` overrides logical model resolution
+- a `cache_dir` override takes effect
+
+### Dependency and initialization tests
+
+- the startup error message is correct when `llama-cpp-python` is missing
+- an explicit local backend with the dependency installed initializes successfully
+- an invalid GGUF path triggers a model load failure
+- a failed download produces a complete, readable error message
+
+### Embedding behavior tests
+
+- `embed_query()` and `embed_document()` take different paths
+- the returned dimension matches the model dimension
+- `embed_batch()` preserves result order and count
+
+### Metadata and rebuild tests
+
+- the first startup produces metadata consistent with the current model
+- changing the model identity triggers rebuild-required
+- changing the dimension triggers rebuild-required
+
+### Retrieval regression tests
+
+- Chinese queries correctly recall Chinese documents
+- the query/document dual-path encoding does not break the existing retrieval pipeline
+- existing remote provider behavior stays unchanged
+
+### Packaging tests
+
+- `pip install openviking` succeeds without the local dependency installed
+- `pip install "openviking[local-embed]"` enables the local import path
+- when the extra is missing and the default local behavior is triggered, the error is explicit rather than a vague import failure
+
+## Benchmarks
+
+Record at least the following metrics:
+
+- startup time with the dependency installed and the model already cached
+- startup time on first model download
+- single embedding latency
+- batch embedding latency
+- index-building throughput on a representative Chinese corpus
+
+Until these benchmarks exist, do not assume that "local embedding by default" is equally suitable in every environment.
+
+## Operations Notes
+
+Recommended install command:
+
+```bash
+pip install "openviking[local-embed]"
+```
+
+To use remote embedding instead:
+
+- explicitly configure `embedding.dense.backend`
+- provide the credentials for the corresponding provider
+
+## Risks
+
+- native dependency installation failures
+- insufficient prebuilt wheel coverage
+- GGUF incompatibility with the runtime version
+- users expecting "zero config" while the local extra is missing
+- index incompatibility after switching models
+- low indexing throughput before batch aggregation is in place
+
+## Deliverables
+
+- local dense embedder implementation
+- local backend configuration and factory integration
+- built-in model registry
+- startup error messages and installation guidance
+- collection metadata validation
+- rebuild-required mechanism
+- test and benchmark scaffolding
+- user documentation updates
diff --git a/openviking/models/embedder/__init__.py b/openviking/models/embedder/__init__.py
index 9b5be66e5..bdad90d71 100644
--- a/openviking/models/embedder/__init__.py
+++ b/openviking/models/embedder/__init__.py
@@ -33,6 +33,7 @@
 except ImportError:
     GeminiDenseEmbedder = None  # google-genai not installed
 from openviking.models.embedder.jina_embedders import JinaDenseEmbedder
+from openviking.models.embedder.local_embedders import LocalDenseEmbedder
 
 try:
     from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder
@@ -66,6 +67,7 @@
     "GeminiDenseEmbedder",
     # Jina AI implementations
     "JinaDenseEmbedder",
+    "LocalDenseEmbedder",
     # LiteLLM implementations
     "LiteLLMDenseEmbedder",
     # MiniMax implementations
diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py
index f50597344..765559f4c 100644
--- a/openviking/models/embedder/base.py
+++ b/openviking/models/embedder/base.py
@@ -157,6 +157,22 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes
         """
         return [self.embed(text, is_query=is_query) for text in texts]
 
+    def embed_query(self, text: str) -> EmbedResult:
+        """Embed query text with explicit retrieval-side semantics."""
+        return self.embed(text, is_query=True)
+
+    def embed_document(self, text: str) -> EmbedResult:
+        """Embed document text with explicit indexing-side semantics."""
+        return self.embed(text, is_query=False)
+
+    def embed_batch_query(self, texts: List[str]) -> List[EmbedResult]:
+        """Batch embed query texts."""
+        return self.embed_batch(texts, is_query=True)
+
+    def embed_batch_document(self, texts: List[str]) -> List[EmbedResult]:
+        """Batch embed document texts."""
+        return self.embed_batch(texts, is_query=False)
+
     async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult:
         """Async embed single text.
@@ -175,6 +191,18 @@ async def embed_batch_async( results.append(await self.embed_async(text, is_query=is_query)) return results + async def embed_query_async(self, text: str) -> EmbedResult: + return await self.embed_async(text, is_query=True) + + async def embed_document_async(self, text: str) -> EmbedResult: + return await self.embed_async(text, is_query=False) + + async def embed_batch_query_async(self, texts: List[str]) -> List[EmbedResult]: + return await self.embed_batch_async(texts, is_query=True) + + async def embed_batch_document_async(self, texts: List[str]) -> List[EmbedResult]: + return await self.embed_batch_async(texts, is_query=False) + def close(self): """Release resources, subclasses can override as needed""" pass diff --git a/openviking/models/embedder/local_embedders.py b/openviking/models/embedder/local_embedders.py new file mode 100644 index 000000000..0ad37be66 --- /dev/null +++ b/openviking/models/embedder/local_embedders.py @@ -0,0 +1,307 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 +"""Local GGUF embedders powered by llama-cpp-python.""" + +from __future__ import annotations + +import importlib +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import requests + +from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.storage.errors import EmbeddingConfigurationError + +logger = logging.getLogger(__name__) + +DEFAULT_LOCAL_MODEL_CACHE_DIR = "~/.cache/openviking/models" +DEFAULT_LOCAL_DENSE_MODEL = "bge-small-zh-v1.5-f16" +DEFAULT_BGE_ZH_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:" + + +@dataclass(frozen=True) +class LocalModelSpec: + model_name: str + dimension: int + filename: str + download_url: str + query_instruction: Optional[str] = None + + +LOCAL_DENSE_MODEL_SPECS: Dict[str, LocalModelSpec] = { + DEFAULT_LOCAL_DENSE_MODEL: LocalModelSpec( + model_name=DEFAULT_LOCAL_DENSE_MODEL, + dimension=512, + filename="bge-small-zh-v1.5-f16.gguf", + download_url=( + "https://huggingface.co/CompendiumLabs/bge-small-zh-v1.5-gguf/resolve/main/" + "bge-small-zh-v1.5-f16.gguf?download=true" + ), + query_instruction=DEFAULT_BGE_ZH_QUERY_INSTRUCTION, + ) +} + + +def get_local_model_spec(model_name: str) -> LocalModelSpec: + try: + return LOCAL_DENSE_MODEL_SPECS[model_name] + except KeyError as exc: + raise ValueError( + f"Unknown local embedding model '{model_name}'. 
" + f"Supported models: {list(LOCAL_DENSE_MODEL_SPECS.keys())}" + ) from exc + + +def get_local_model_default_dimension(model_name: str) -> int: + return get_local_model_spec(model_name).dimension + + +def get_local_model_cache_path(model_name: str, cache_dir: Optional[str] = None) -> Path: + spec = get_local_model_spec(model_name) + cache_root = Path(cache_dir or DEFAULT_LOCAL_MODEL_CACHE_DIR).expanduser().resolve() + return cache_root / spec.filename + + +def get_local_model_identity(model_name: str, model_path: Optional[str] = None) -> str: + if model_path: + resolved = Path(model_path).expanduser().resolve() + return str(resolved) + return get_local_model_spec(model_name).filename + + +class LocalDenseEmbedder(DenseEmbedderBase): + """Dense embedder backed by a local GGUF model via llama-cpp-python.""" + + def __init__( + self, + model_name: str = DEFAULT_LOCAL_DENSE_MODEL, + model_path: Optional[str] = None, + cache_dir: Optional[str] = None, + dimension: Optional[int] = None, + query_instruction: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ): + runtime_config = dict(config or {}) + runtime_config.setdefault("provider", "local") + super().__init__(model_name, runtime_config) + + self.model_spec = get_local_model_spec(model_name) + self.model_path = model_path + self.cache_dir = cache_dir or DEFAULT_LOCAL_MODEL_CACHE_DIR + self.query_instruction = ( + query_instruction + if query_instruction is not None + else self.model_spec.query_instruction + ) + self._dimension = dimension or self.model_spec.dimension + if self._dimension != self.model_spec.dimension: + raise ValueError( + f"Local model '{model_name}' has fixed dimension {self.model_spec.dimension}, " + f"but got dimension={self._dimension}" + ) + + self._resolved_model_path = self._resolve_model_path() + self._llama = self._load_model() + + def _import_llama(self): + try: + module = importlib.import_module("llama_cpp") + except ImportError as exc: + raise EmbeddingConfigurationError( + "Local embedding is enabled but 'llama-cpp-python' is not installed. " + 'Install it with: pip install "openviking[local-embed]". ' + "If you prefer a remote provider, set embedding.dense.provider explicitly in ov.conf." + ) from exc + + llama_cls = getattr(module, "Llama", None) + if llama_cls is None: + raise EmbeddingConfigurationError( + "llama_cpp.Llama is unavailable in the installed llama-cpp-python package." 
+ ) + return llama_cls + + def _resolve_model_path(self) -> Path: + if self.model_path: + resolved = Path(self.model_path).expanduser().resolve() + if not resolved.exists(): + raise EmbeddingConfigurationError( + f"Local embedding model file not found: {resolved}" + ) + return resolved + + cache_root = Path(self.cache_dir).expanduser().resolve() + cache_root.mkdir(parents=True, exist_ok=True) + target = get_local_model_cache_path(self.model_name, self.cache_dir) + if target.exists(): + return target + + self._download_model(self.model_spec.download_url, target) + return target + + def _download_model(self, url: str, target: Path) -> None: + logger.info("Downloading local embedding model %s to %s", self.model_name, target) + tmp_target = target.with_suffix(target.suffix + ".part") + try: + with requests.get(url, stream=True, timeout=(10, 300)) as response: + response.raise_for_status() + with tmp_target.open("wb") as fh: + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + fh.write(chunk) + os.replace(tmp_target, target) + except Exception as exc: + tmp_target.unlink(missing_ok=True) + raise EmbeddingConfigurationError( + f"Failed to download local embedding model '{self.model_name}' from {url} " + f"to {target}: {exc}" + ) from exc + + def _load_model(self): + llama_cls = self._import_llama() + try: + return llama_cls( + model_path=str(self._resolved_model_path), + embedding=True, + verbose=False, + ) + except Exception as exc: + raise EmbeddingConfigurationError( + f"Failed to load GGUF embedding model from {self._resolved_model_path}: {exc}" + ) from exc + + def _format_text(self, text: str, *, is_query: bool) -> str: + if is_query and self.query_instruction: + return f"{self.query_instruction}{text}" + return text + + def _supports_native_batch_embeddings(self) -> bool: + context_params = getattr(self._llama, "context_params", None) + n_seq_max = getattr(context_params, "n_seq_max", 1) + return n_seq_max > 1 + + @staticmethod + def _extract_embedding(payload: Any) -> List[float]: + if isinstance(payload, dict): + data = payload.get("data") + if isinstance(data, list) and data: + item = data[0] + if isinstance(item, dict) and "embedding" in item: + return list(item["embedding"]) + if "embedding" in payload: + return list(payload["embedding"]) + raise RuntimeError("Unexpected llama-cpp-python embedding response format") + + @staticmethod + def _extract_embeddings(payload: Any) -> List[List[float]]: + if isinstance(payload, dict): + data = payload.get("data") + if isinstance(data, list): + vectors: List[List[float]] = [] + for item in data: + if not isinstance(item, dict) or "embedding" not in item: + raise RuntimeError( + "Unexpected llama-cpp-python batch embedding response format" + ) + vectors.append(list(item["embedding"])) + return vectors + raise RuntimeError("Unexpected llama-cpp-python batch embedding response format") + + def _embed_formatted_text(self, formatted: str) -> EmbedResult: + payload = self._llama.create_embedding(formatted) + return EmbedResult(dense_vector=self._extract_embedding(payload)) + + def _embed_formatted_texts_sequential(self, formatted: List[str]) -> List[EmbedResult]: + return [ + self._run_with_retry( + lambda formatted_text=text: self._embed_formatted_text(formatted_text), + logger=logger, + operation_name="local sequential batch embedding", + ) + for text in formatted + ] + + def embed(self, text: str, is_query: bool = False) -> EmbedResult: + formatted = self._format_text(text, is_query=is_query) + + try: + result = 
self._run_with_retry( + lambda: self._embed_formatted_text(formatted), + logger=logger, + operation_name="local embedding", + ) + except Exception as exc: + raise RuntimeError(f"Local embedding failed: {exc}") from exc + + estimated_tokens = self._estimate_tokens(formatted) + self.update_token_usage( + model_name=self.model_name, + provider="local", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: + if not texts: + return [] + + formatted = [self._format_text(text, is_query=is_query) for text in texts] + if len(formatted) > 1 and not self._supports_native_batch_embeddings(): + logger.info( + "Local model %s does not support native multi-sequence batch embedding " + "(n_seq_max <= 1); using sequential mode", + self.model_name, + ) + results = self._embed_formatted_texts_sequential(formatted) + estimated_tokens = sum(self._estimate_tokens(text) for text in formatted) + self.update_token_usage( + model_name=self.model_name, + provider="local", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return results + + def _call_batch() -> List[EmbedResult]: + payload = self._llama.create_embedding(formatted) + return [ + EmbedResult(dense_vector=vector) for vector in self._extract_embeddings(payload) + ] + + try: + results = self._run_with_retry( + _call_batch, + logger=logger, + operation_name="local batch embedding", + ) + except Exception as batch_exc: + logger.warning( + "Local batch embedding failed for model=%s (%s); falling back to sequential embedding", + self.model_name, + batch_exc, + ) + try: + results = self._embed_formatted_texts_sequential(formatted) + except Exception as exc: + raise RuntimeError(f"Local batch embedding failed: {exc}") from exc + + estimated_tokens = sum(self._estimate_tokens(text) for text in formatted) + self.update_token_usage( + model_name=self.model_name, + provider="local", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return results + + def get_dimension(self) -> int: + return self._dimension + + def close(self): + close_fn = getattr(self._llama, "close", None) + if callable(close_fn): + close_fn() diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index d63f1dcc5..3ad5d0c7b 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -18,7 +18,11 @@ from openviking.models.embedder.base import EmbedResult, embed_compat from openviking.server.identity import RequestContext, Role -from openviking.storage.errors import CollectionNotFoundError +from openviking.storage.errors import ( + CollectionNotFoundError, + EmbeddingConfigurationError, + EmbeddingRebuildRequiredError, +) from openviking.storage.queuefs.embedding_msg import EmbeddingMsg from openviking.storage.queuefs.named_queue import DequeueHandlerBase from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend @@ -34,6 +38,7 @@ from openviking_cli.utils.config.open_viking_config import OpenVikingConfig logger = get_logger(__name__) +EMBEDDING_META_MARKER = "\n\n[openviking.embedding]\n" @dataclass @@ -49,7 +54,9 @@ class CollectionSchemas: """ @staticmethod - def context_collection(name: str, vector_dim: int) -> Dict[str, Any]: + def context_collection( + name: str, vector_dim: int, description: Optional[str] = None + ) -> Dict[str, Any]: """ Get the schema for the unified context collection. 
@@ -118,12 +125,75 @@ def context_collection(name: str, vector_dim: int) -> Dict[str, Any]: ) return { "CollectionName": name, - "Description": "Unified context collection", + "Description": description or "Unified context collection", "Fields": fields, "ScalarIndex": scalar_index, } +def _get_active_embedding_model_config(config: "OpenVikingConfig") -> Any: + embedding_cfg = config.embedding + if embedding_cfg.hybrid is not None: + return embedding_cfg.hybrid + if embedding_cfg.dense is not None: + return embedding_cfg.dense + if embedding_cfg.sparse is not None: + return embedding_cfg.sparse + raise ValueError("No active embedding model configuration found") + + +def _build_embedding_metadata(config: "OpenVikingConfig") -> Dict[str, Any]: + model_cfg = _get_active_embedding_model_config(config) + provider = ( + getattr(model_cfg, "provider", None) or getattr(model_cfg, "backend", None) or "" + ).lower() + model = getattr(model_cfg, "model", None) or "" + dimension = config.embedding.dimension + model_path = getattr(model_cfg, "model_path", None) + model_identity = model + + if provider == "local": + try: + from openviking.models.embedder.local_embedders import get_local_model_identity + + resolved_identity = get_local_model_identity(model, model_path=model_path) + model_identity = str(hashlib.sha256(resolved_identity.encode("utf-8")).hexdigest()) + except Exception: + model_identity = model + + return { + "provider": provider, + "model": model, + "dimension": dimension, + "model_identity": model_identity, + } + + +def _encode_collection_description( + base_description: str, + embedding_meta: Dict[str, Any], +) -> str: + description = (base_description or "Unified context collection").strip() + meta_json = json.dumps(embedding_meta, sort_keys=True, ensure_ascii=False) + return f"{description}{EMBEDDING_META_MARKER}{meta_json}" + + +def _decode_collection_description( + description: Optional[str], +) -> tuple[str, Optional[Dict[str, Any]]]: + text = description or "" + if EMBEDDING_META_MARKER not in text: + return text, None + + base, meta_json = text.split(EMBEDDING_META_MARKER, 1) + try: + payload = json.loads(meta_json.strip()) + except json.JSONDecodeError: + logger.warning("Failed to parse collection embedding metadata from description") + return text, None + return base.strip(), payload if isinstance(payload, dict) else None + + async def init_context_collection(storage) -> bool: """ Initialize the context collection with proper schema. 
@@ -142,8 +212,53 @@ async def init_context_collection(storage) -> bool: if not name: raise ValueError("Vector DB collection name is required") collection_name = name - schema = CollectionSchemas.context_collection(collection_name, vector_dim) - return await storage.create_collection(collection_name, schema) + embedding_meta = _build_embedding_metadata(config) + schema = CollectionSchemas.context_collection( + collection_name, + vector_dim, + description=_encode_collection_description("Unified context collection", embedding_meta), + ) + created = await storage.create_collection(collection_name, schema) + if created: + return True + + existing_meta = None + if hasattr(storage, "get_collection_meta"): + existing_meta = await storage.get_collection_meta() + + if not existing_meta: + raise EmbeddingConfigurationError( + "Existing collection metadata is unavailable; cannot validate embedding compatibility" + ) + + base_description, existing_embedding_meta = _decode_collection_description( + existing_meta.get("Description") + ) + if existing_embedding_meta == embedding_meta: + return False + + existing_count = await storage.count() if hasattr(storage, "count") else 0 + if existing_embedding_meta is None and existing_count == 0: + if hasattr(storage, "update_collection_description"): + await storage.update_collection_description( + _encode_collection_description( + base_description or "Unified context collection", + embedding_meta, + ) + ) + return False + + if existing_embedding_meta is None: + raise EmbeddingRebuildRequiredError( + "Existing collection is missing embedding metadata and already contains vectors. " + "Please rebuild the collection before continuing, or switch back to the previous embedding config." + ) + + raise EmbeddingRebuildRequiredError( + "Existing collection embedding metadata does not match current configuration. " + f"existing={existing_embedding_meta}, current={embedding_meta}. " + "Rebuild is required before OpenViking can continue, or switch back to the previous embedding config." 
+ ) class TextEmbeddingHandler(DequeueHandlerBase): @@ -326,7 +441,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, _embed_t0 = _time.monotonic() result: EmbedResult = await embed_compat( - self._embedder, embedding_msg.message + self._embedder, embedding_msg.message, is_query=False ) _embed_elapsed = _time.monotonic() - _embed_t0 try: diff --git a/openviking/storage/errors.py b/openviking/storage/errors.py index 84840fd27..04c7e3e09 100644 --- a/openviking/storage/errors.py +++ b/openviking/storage/errors.py @@ -31,6 +31,14 @@ class SchemaError(StorageException): """Raised when schema validation fails.""" +class EmbeddingConfigurationError(StorageException): + """Raised when embedding provider setup is invalid or unavailable.""" + + +class EmbeddingRebuildRequiredError(StorageException): + """Raised when existing vector data is incompatible with current embedding config.""" + + class LockError(VikingDBException): """Raised when a lock operation fails.""" diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index ce5dc8e7b..f00703cf6 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -202,6 +202,19 @@ async def get_collection_info(self) -> Optional[Dict[str, Any]]: "status": "active", } + async def get_collection_meta(self) -> Optional[Dict[str, Any]]: + if not await self.collection_exists(): + return None + return self._get_collection().get_meta_data() + + async def update_collection_description(self, description: str) -> bool: + if not await self.collection_exists(): + return False + coll = self._get_collection() + coll.update(description=description) + self._refresh_meta_data(coll) + return True + # ========================================================================= # Data Operations (with tenant enforcement) # ========================================================================= @@ -587,6 +600,12 @@ async def collection_exists_bound(self) -> bool: async def get_collection_info(self) -> Optional[Dict[str, Any]]: return await self._get_default_backend().get_collection_info() + async def get_collection_meta(self) -> Optional[Dict[str, Any]]: + return await self._get_default_backend().get_collection_meta() + + async def update_collection_description(self, description: str) -> bool: + return await self._get_default_backend().update_collection_description(description) + # ========================================================================= # 公开数据操作 API(强制要求 ctx) # ========================================================================= diff --git a/openviking/storage/vikingdb_manager.py b/openviking/storage/vikingdb_manager.py index 899e6a1af..800002cad 100644 --- a/openviking/storage/vikingdb_manager.py +++ b/openviking/storage/vikingdb_manager.py @@ -276,6 +276,12 @@ async def collection_exists_bound(self) -> bool: async def get_collection_info(self) -> Optional[Dict[str, Any]]: return await self._manager.get_collection_info() + async def get_collection_meta(self) -> Optional[Dict[str, Any]]: + return await self._manager.get_collection_meta() + + async def update_collection_description(self, description: str) -> bool: + return await self._manager.update_collection_description(description) + # ========================================================================= # 数据操作 API(自动携带 ctx) # ========================================================================= diff --git a/openviking_cli/doctor.py 
b/openviking_cli/doctor.py index bcb231aaf..00dbad86b 100644 --- a/openviking_cli/doctor.py +++ b/openviking_cli/doctor.py @@ -77,7 +77,7 @@ def check_config() -> tuple[bool, str, Optional[str]]: except json.JSONDecodeError as exc: return False, f"Invalid JSON in {config_path}", f"Fix syntax error: {exc}" - missing = [key for key in ("embedding",) if key not in data] + missing = [key for key in () if key not in data] if missing: return ( False, @@ -152,13 +152,54 @@ def check_embedding() -> tuple[bool, str, Optional[str]]: if data is None: return False, "Cannot check (config unreadable)", None - embedding = data.get("embedding", {}) - dense = embedding.get("dense", {}) - provider = dense.get("provider", "unknown") - model = dense.get("model", "unknown") + embedding = data.get("embedding", {}) or {} + dense = embedding.get("dense", {}) or {} + provider = dense.get("provider", "local") + model = dense.get("model", "bge-small-zh-v1.5-f16") - if provider == "unknown": - return False, "No embedding provider configured", "Add embedding.dense section to ov.conf" + if provider == "local": + from openviking.models.embedder.local_embedders import ( + get_local_model_cache_path, + get_local_model_spec, + ) + + try: + get_local_model_spec(model) + except ValueError as exc: + return ( + False, + f"{provider}/{model} (unsupported local model)", + str(exc), + ) + + try: + importlib.import_module("llama_cpp") + except ImportError: + return ( + False, + f"{provider}/{model} (missing llama-cpp-python)", + 'pip install "openviking[local-embed]"', + ) + + model_path = dense.get("model_path", "") + cache_dir = Path(dense.get("cache_dir", "~/.cache/openviking/models")).expanduser() + if model_path: + if not Path(model_path).expanduser().exists(): + return ( + False, + f"{provider}/{model} (model_path missing)", + f"Download the GGUF model to {Path(model_path).expanduser()} or update embedding.dense.model_path", + ) + return True, f"{provider}/{model} ({Path(model_path).expanduser()})", None + + cached_file = get_local_model_cache_path(model, str(cache_dir)) + if cached_file.exists(): + return True, f"{provider}/{model} ({cached_file})", None + return ( + True, + f"{provider}/{model} (will auto-download during startup initialization)", + None, + ) api_key = dense.get("api_key", "") if not api_key or api_key.startswith("{"): diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index c64068843..dbd3bdac1 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -37,14 +37,14 @@ class EmbeddingModelConfig(BaseModel): provider: Optional[str] = Field( default="volcengine", description=( - "Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm'. " + "Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm', 'local'. " "For OpenRouter or other OpenAI-compatible providers, use 'litellm' with " "api_base and api_key, or 'openai' with api_base and extra_headers." 
), ) backend: Optional[str] = Field( default="volcengine", - description="Backend type (Deprecated, use 'provider' instead): 'openai', 'volcengine', 'vikingdb', 'voyage'", + description="Backend type (Deprecated, use 'provider' instead): 'openai', 'volcengine', 'vikingdb', 'voyage', 'local'", ) version: Optional[str] = Field(default=None, description="Model version") ak: Optional[str] = Field(default=None, description="Access Key ID for VikingDB API") @@ -63,6 +63,14 @@ class EmbeddingModelConfig(BaseModel): default=None, description="API version for Azure OpenAI (e.g., '2025-01-01-preview').", ) + model_path: Optional[str] = Field( + default=None, + description="Explicit local GGUF model path for provider='local'.", + ) + cache_dir: Optional[str] = Field( + default=None, + description="Local model cache directory for provider='local'.", + ) model_config = {"extra": "forbid"} @@ -105,10 +113,11 @@ def validate_config(self): "minimax", "cohere", "litellm", + "local", ]: raise ValueError( f"Invalid embedding provider: '{self.provider}'. Must be one of: " - "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'cohere', 'litellm'" + "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'cohere', 'litellm', 'local'" ) # Provider-specific validation @@ -192,6 +201,11 @@ def validate_config(self): "Check your embedding model's documentation for the correct dimension." ) + elif self.provider == "local": + from openviking.models.embedder.local_embedders import get_local_model_spec + + get_local_model_spec(self.model) + return self def get_effective_dimension(self) -> int: @@ -243,6 +257,11 @@ def get_effective_dimension(self) -> int: f"Known models: {list(ollama_model_dimensions.keys())}" ) + if provider == "local": + from openviking.models.embedder.local_embedders import get_local_model_default_dimension + + return get_local_model_default_dimension(self.model) + return 2048 @@ -308,6 +327,22 @@ class EmbeddingConfig(BaseModel): model_config = {"extra": "forbid"} + @model_validator(mode="before") + @classmethod + def apply_default_local_dense(cls, data: Any) -> Any: + if data is None: + data = {} + if not isinstance(data, dict): + return data + + if not data.get("dense") and not data.get("sparse") and not data.get("hybrid"): + data = dict(data) + data["dense"] = { + "provider": "local", + "model": "bge-small-zh-v1.5-f16", + } + return data + @model_validator(mode="after") def validate_config(self): """Validate configuration completeness and consistency""" @@ -345,6 +380,7 @@ def _create_embedder( GeminiDenseEmbedder, JinaDenseEmbedder, LiteLLMDenseEmbedder, + LocalDenseEmbedder, MinimaxDenseEmbedder, OpenAIDenseEmbedder, VikingDBDenseEmbedder, @@ -549,6 +585,16 @@ def _create_embedder( **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), }, ), + ("local", "dense"): ( + LocalDenseEmbedder, + lambda cfg: { + "model_name": cfg.model, + "model_path": cfg.model_path, + "cache_dir": cfg.cache_dir, + "dimension": cfg.dimension, + "config": dict(runtime_config), + }, + ), } key = (provider, embedder_type) diff --git a/pyproject.toml b/pyproject.toml index 6451144c7..949c538d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,6 +174,9 @@ benchmark = [ "datasets>=2.0.0", "pandas>=2.0.0", ] +local-embed = [ + "llama-cpp-python>=0.3.0", +] [project.urls] Homepage = "https://github.com/volcengine/openviking" diff --git a/tests/cli/test_doctor.py b/tests/cli/test_doctor.py index 428975053..a677d043f 100644 
--- a/tests/cli/test_doctor.py +++ b/tests/cli/test_doctor.py @@ -46,13 +46,13 @@ def test_fail_invalid_json(self, tmp_path: Path): assert not ok assert "Invalid JSON" in detail - def test_fail_missing_embedding_section(self, tmp_path: Path): + def test_pass_without_embedding_section(self, tmp_path: Path): config = tmp_path / "ov.conf" config.write_text(json.dumps({"server": {}})) with patch("openviking_cli.doctor._find_config", return_value=config): ok, detail, fix = check_config() - assert not ok - assert "embedding" in detail + assert ok + assert str(config) in detail class TestCheckPython: @@ -138,6 +138,86 @@ def test_fail_when_missing(self): class TestCheckEmbedding: + def test_fail_local_default_when_optional_dependency_missing(self, tmp_path: Path): + config = tmp_path / "ov.conf" + config.write_text(json.dumps({})) + + with patch("openviking_cli.doctor._find_config", return_value=config): + with patch( + "openviking_cli.doctor.importlib.import_module", + side_effect=ImportError("No module named 'llama_cpp'"), + ): + ok, detail, fix = check_embedding() + + assert not ok + assert "missing llama-cpp-python" in detail + assert "openviking[local-embed]" in fix + + def test_pass_local_default_with_cached_model(self, tmp_path: Path): + config = tmp_path / "ov.conf" + config.write_text(json.dumps({})) + cached_model = ( + Path.home() / ".cache" / "openviking" / "models" / "bge-small-zh-v1.5-f16.gguf" + ) + real_import = __import__ + + with patch("openviking_cli.doctor._find_config", return_value=config): + with patch( + "openviking.models.embedder.local_embedders.get_local_model_cache_path", + return_value=cached_model, + ): + with patch.object(Path, "exists", autospec=True, return_value=True): + with patch( + "openviking_cli.doctor.importlib.import_module", + side_effect=lambda name: object() + if name == "llama_cpp" + else real_import(name), + ): + ok, detail, fix = check_embedding() + + assert ok + assert "local/bge-small-zh-v1.5-f16" in detail + assert fix is None + + def test_pass_local_default_reports_startup_download_when_cache_missing(self, tmp_path: Path): + config = tmp_path / "ov.conf" + config.write_text(json.dumps({})) + real_import = __import__ + + with patch("openviking_cli.doctor._find_config", return_value=config): + with patch.object(Path, "exists", autospec=True, return_value=False): + with patch( + "openviking_cli.doctor.importlib.import_module", + side_effect=lambda name: object() if name == "llama_cpp" else real_import(name), + ): + ok, detail, fix = check_embedding() + + assert ok + assert "startup initialization" in detail + assert fix is None + + def test_fail_local_unknown_model(self, tmp_path: Path): + config = tmp_path / "ov.conf" + config.write_text( + json.dumps( + { + "embedding": { + "dense": { + "provider": "local", + "model": "unknown-local-model", + } + } + } + ) + ) + + with patch("openviking_cli.doctor._find_config", return_value=config): + ok, detail, fix = check_embedding() + + assert not ok + assert "unsupported local model" in detail + assert "Unknown local embedding model" in fix + def test_pass_with_api_key(self, tmp_path: Path): config = tmp_path / "ov.conf" config.write_text( diff --git a/tests/misc/test_config_validation.py b/tests/misc/test_config_validation.py index e2985ac52..6c9b9f23c 100644 --- a/tests/misc/test_config_validation.py +++ b/tests/misc/test_config_validation.py @@ -107,13 +107,15 @@ def test_embedding_validation(): print("Test Embedding config validation") print("=" * 60) - # Test 1: no embedder config + # Test 1: no 
embedder config -> default local dense print("\n1. Test no embedder config...") try: - _ = EmbeddingConfig() - print(" Should fail but passed") + config = EmbeddingConfig() + print( + f" Pass (default provider={config.dense.provider}, model={config.dense.model}, dim={config.dimension})" + ) except ValueError as e: - print(f" Correctly raised exception: {e}") + print(f" Fail: {e}") # Test 2: OpenAI provider missing api_key print("\n2. Test OpenAI provider missing api_key...") diff --git a/tests/storage/test_collection_schemas.py b/tests/storage/test_collection_schemas.py index 8047a9d15..f39c3d269 100644 --- a/tests/storage/test_collection_schemas.py +++ b/tests/storage/test_collection_schemas.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 import asyncio +import hashlib import inspect import json import logging @@ -13,8 +14,10 @@ from openviking.storage.collection_schemas import ( CollectionSchemas, TextEmbeddingHandler, + _build_embedding_metadata, init_context_collection, ) +from openviking.storage.errors import EmbeddingRebuildRequiredError from openviking.storage.queuefs.embedding_msg import EmbeddingMsg from openviking.storage.viking_vector_index_backend import _SingleAccountBackend from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig @@ -24,7 +27,8 @@ class _DummyEmbedder: def __init__(self): self.calls = 0 - def embed(self, text: str) -> EmbedResult: + def embed(self, text: str, is_query: bool = False) -> EmbedResult: + del is_query self.calls += 1 return EmbedResult(dense_vector=[0.1, 0.2]) @@ -35,6 +39,13 @@ def __init__(self, embedder: _DummyEmbedder, backend: str = "volcengine"): self.embedding = SimpleNamespace( dimension=2, get_embedder=lambda: embedder, + dense=SimpleNamespace( + provider="local", + model="bge-small-zh-v1.5-f16", + model_path=None, + ), + sparse=None, + hybrid=None, circuit_breaker=SimpleNamespace( failure_threshold=5, reset_timeout=60.0, @@ -79,6 +90,109 @@ class _DummyVikingDB: assert handler._circuit_breaker._max_reset_timeout == 600.0 +@pytest.mark.asyncio +async def test_init_context_collection_writes_embedding_metadata(monkeypatch): + captured = {} + + class _FakeStorage: + async def create_collection(self, name, schema): + captured["name"] = name + captured["schema"] = schema + return True + + config = _DummyConfig(_DummyEmbedder()) + monkeypatch.setattr( + "openviking_cli.utils.config.get_openviking_config", + lambda: config, + ) + + created = await init_context_collection(_FakeStorage()) + + assert created is True + description = captured["schema"]["Description"] + assert "[openviking.embedding]" in description + assert '"provider": "local"' in description + assert '"model": "bge-small-zh-v1.5-f16"' in description + + +@pytest.mark.asyncio +async def test_init_context_collection_backfills_metadata_for_empty_legacy_collection(monkeypatch): + updates = [] + + class _FakeStorage: + async def create_collection(self, name, schema): + del name, schema + return False + + async def get_collection_meta(self): + return {"Description": "Unified context collection"} + + async def count(self): + return 0 + + async def update_collection_description(self, description): + updates.append(description) + return True + + config = _DummyConfig(_DummyEmbedder()) + monkeypatch.setattr( + "openviking_cli.utils.config.get_openviking_config", + lambda: config, + ) + + created = await init_context_collection(_FakeStorage()) + + assert created is False + assert len(updates) == 1 + assert '"provider": "local"' in updates[0] + + 
+@pytest.mark.asyncio +async def test_init_context_collection_rejects_mismatched_nonempty_collection(monkeypatch): + class _FakeStorage: + async def create_collection(self, name, schema): + del name, schema + return False + + async def get_collection_meta(self): + return { + "Description": ( + "Unified context collection\n\n[openviking.embedding]\n" + '{"dimension": 1024, "model": "text-embedding-3-small", ' + '"model_identity": "text-embedding-3-small", "provider": "openai"}' + ) + } + + async def count(self): + return 3 + + async def update_collection_description(self, description): # pragma: no cover + del description + raise AssertionError("should not update mismatched non-empty collection") + + config = _DummyConfig(_DummyEmbedder()) + monkeypatch.setattr( + "openviking_cli.utils.config.get_openviking_config", + lambda: config, + ) + + with pytest.raises(EmbeddingRebuildRequiredError, match="Rebuild is required"): + await init_context_collection(_FakeStorage()) + + +def test_build_embedding_metadata_hashes_resolved_local_model_path(tmp_path): + model_path = tmp_path / ".." / tmp_path.name / "model.gguf" + expected = str(model_path.expanduser().resolve()) + config = _DummyConfig(_DummyEmbedder()) + config.embedding.dense.model_path = str(model_path) + + payload = _build_embedding_metadata(config) + + assert payload["provider"] == "local" + assert payload["model"] == "bge-small-zh-v1.5-f16" + assert payload["model_identity"] == hashlib.sha256(expected.encode("utf-8")).hexdigest() + + @pytest.mark.asyncio async def test_embedding_handler_skip_all_work_when_manager_is_closing(monkeypatch): class _ClosingVikingDB: diff --git a/tests/unit/test_local_embedder.py b/tests/unit/test_local_embedder.py new file mode 100644 index 000000000..a0f0ddb40 --- /dev/null +++ b/tests/unit/test_local_embedder.py @@ -0,0 +1,224 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: AGPL-3.0 + +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from openviking.models.embedder.local_embedders import ( + DEFAULT_BGE_ZH_QUERY_INSTRUCTION, + DEFAULT_LOCAL_DENSE_MODEL, + LocalDenseEmbedder, +) +from openviking.storage.errors import EmbeddingConfigurationError +from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig + + +class _FakeResponse: + def __init__(self, payload: bytes): + self.payload = payload + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def raise_for_status(self): + return None + + def iter_content(self, chunk_size=1024 * 1024): + del chunk_size + yield self.payload + + +class _FakeLlama: + init_kwargs = [] + inputs = [] + + def __init__(self, **kwargs): + self.__class__.init_kwargs.append(kwargs) + self.context_params = SimpleNamespace(n_seq_max=2) + + def create_embedding(self, payload): + self.__class__.inputs.append(payload) + if isinstance(payload, list): + return { + "data": [ + {"embedding": [float(index)] * 512} + for index, _item in enumerate(payload, start=1) + ] + } + return {"data": [{"embedding": [0.1] * 512}]} + + +class _FakeLlamaFailBatch(_FakeLlama): + def create_embedding(self, payload): + self.__class__.inputs.append(payload) + if isinstance(payload, list) and len(payload) > 1: + raise RuntimeError("llama_decode returned -1") + return {"data": [{"embedding": [0.2] * 512}]} + + +class _FakeLlamaSequentialOnly(_FakeLlama): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.context_params = SimpleNamespace(n_seq_max=1) + + def create_embedding(self, payload): + self.__class__.inputs.append(payload) + if isinstance(payload, list): + raise AssertionError("native batch path should not be used when n_seq_max=1") + return {"data": [{"embedding": [0.3] * 512}]} + + +@pytest.fixture(autouse=True) +def _reset_fake_llama(): + _FakeLlama.init_kwargs = [] + _FakeLlama.inputs = [] + + +def test_embedding_config_defaults_to_local_dense(): + config = EmbeddingConfig() + + assert config.dense is not None + assert config.dense.provider == "local" + assert config.dense.model == DEFAULT_LOCAL_DENSE_MODEL + assert config.dimension == 512 + + +def test_local_embedding_config_rejects_unknown_model(): + with pytest.raises(ValueError, match="Unknown local embedding model"): + EmbeddingModelConfig( + provider="local", + model="unknown-local-model", + ) + + +def test_local_embedder_requires_optional_dependency(monkeypatch, tmp_path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"gguf") + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: (_ for _ in ()).throw(ImportError("missing llama_cpp")), + ) + + with pytest.raises(EmbeddingConfigurationError, match="openviking\\[local-embed\\]"): + LocalDenseEmbedder(model_path=str(model_path)) + + +def test_local_embedder_uses_explicit_model_path(monkeypatch, tmp_path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"gguf") + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: SimpleNamespace(Llama=_FakeLlama), + ) + + embedder = LocalDenseEmbedder(model_path=str(model_path)) + + assert Path(_FakeLlama.init_kwargs[-1]["model_path"]) == model_path.resolve() + result = embedder.embed("你好", is_query=False) + assert len(result.dense_vector) == 512 + assert _FakeLlama.inputs[-1] == "你好" + + +def 
test_local_embedder_downloads_default_model_and_prefixes_query(monkeypatch, tmp_path): + downloaded = {"count": 0} + + def _fake_get(url, stream=True, timeout=(10, 300)): + assert "bge-small-zh-v1.5-f16.gguf" in url + assert stream is True + assert timeout == (10, 300) + downloaded["count"] += 1 + return _FakeResponse(b"gguf") + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: SimpleNamespace(Llama=_FakeLlama), + ) + monkeypatch.setattr("openviking.models.embedder.local_embedders.requests.get", _fake_get) + + embedder = LocalDenseEmbedder(cache_dir=str(tmp_path)) + + assert downloaded["count"] == 1 + assert (tmp_path / "bge-small-zh-v1.5-f16.gguf").exists() + + result = embedder.embed("测试问题", is_query=True) + assert len(result.dense_vector) == 512 + assert _FakeLlama.inputs[-1] == f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}测试问题" + + +def test_local_embedder_embed_batch_preserves_count(monkeypatch, tmp_path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"gguf") + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: SimpleNamespace(Llama=_FakeLlama), + ) + + embedder = LocalDenseEmbedder(model_path=str(model_path)) + results = embedder.embed_batch(["a", "b"], is_query=True) + + assert len(results) == 2 + assert all(len(item.dense_vector) == 512 for item in results) + assert _FakeLlama.inputs[-1] == [ + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a", + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b", + ] + + +def test_local_embedder_embed_batch_falls_back_to_sequential(monkeypatch, tmp_path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"gguf") + + _FakeLlamaFailBatch.init_kwargs = [] + _FakeLlamaFailBatch.inputs = [] + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: SimpleNamespace(Llama=_FakeLlamaFailBatch), + ) + + embedder = LocalDenseEmbedder(model_path=str(model_path)) + results = embedder.embed_batch(["a", "b"], is_query=True) + + assert len(results) == 2 + assert all(len(item.dense_vector) == 512 for item in results) + assert _FakeLlamaFailBatch.inputs[0] == [ + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a", + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b", + ] + assert _FakeLlamaFailBatch.inputs[1:] == [ + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a", + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b", + ] + + +def test_local_embedder_embed_batch_uses_sequential_mode_when_native_batch_unsupported( + monkeypatch, tmp_path +): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"gguf") + + _FakeLlamaSequentialOnly.init_kwargs = [] + _FakeLlamaSequentialOnly.inputs = [] + + monkeypatch.setattr( + "openviking.models.embedder.local_embedders.importlib.import_module", + lambda _name: SimpleNamespace(Llama=_FakeLlamaSequentialOnly), + ) + + embedder = LocalDenseEmbedder(model_path=str(model_path)) + results = embedder.embed_batch(["a", "b"], is_query=True) + + assert len(results) == 2 + assert all(len(item.dense_vector) == 512 for item in results) + assert _FakeLlamaSequentialOnly.inputs == [ + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a", + f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b", + ]