84 changes: 15 additions & 69 deletions backend/package/yuxi/config/static/models.py
@@ -169,54 +169,20 @@ class RerankerInfo(BaseModel):
 # ============================================================

 DEFAULT_EMBED_MODELS: dict[str, EmbedModelInfo] = {
-    "siliconflow/BAAI/bge-m3": EmbedModelInfo(
-        model_id="siliconflow/BAAI/bge-m3",
-        name="BAAI/bge-m3",
-        dimension=1024,
-        base_url="https://api.siliconflow.cn/v1/embeddings",
-        api_key="SILICONFLOW_API_KEY",
-    ),
-    "siliconflow/Pro/BAAI/bge-m3": EmbedModelInfo(
-        model_id="siliconflow/Pro/BAAI/bge-m3",
-        name="Pro/BAAI/bge-m3",
-        dimension=1024,
-        base_url="https://api.siliconflow.cn/v1/embeddings",
-        api_key="SILICONFLOW_API_KEY",
-    ),
-    "siliconflow/Qwen/Qwen3-Embedding-0.6B": EmbedModelInfo(
-        model_id="siliconflow/Qwen/Qwen3-Embedding-0.6B",
-        name="Qwen/Qwen3-Embedding-0.6B",
-        dimension=1024,
-        base_url="https://api.siliconflow.cn/v1/embeddings",
-        api_key="SILICONFLOW_API_KEY",
-    ),
-    "vllm/Qwen/Qwen3-Embedding-0.6B": EmbedModelInfo(
-        model_id="vllm/Qwen/Qwen3-Embedding-0.6B",
-        name="Qwen3-Embedding-0.6B",
-        dimension=1024,
-        base_url="http://localhost:8000/v1/embeddings",
-        api_key="no_api_key",
-    ),
-    "ollama/nomic-embed-text": EmbedModelInfo(
-        model_id="ollama/nomic-embed-text",
-        name="nomic-embed-text",
-        dimension=768,
-        base_url="http://localhost:11434/api/embed",
-        api_key="no_api_key",
-    ),
-    "ollama/bge-m3": EmbedModelInfo(
-        model_id="ollama/bge-m3",
+    "cusc/bge-m3": EmbedModelInfo(
+        model_id="cusc/bge-m3",
         name="bge-m3",
         dimension=1024,
-        base_url="http://localhost:11434/api/embed",
-        api_key="no_api_key",
+        base_url="http://172.31.153.11:8080/cusc-ai/v1/embeddings",
+        api_key="CUSC_API_KEY",
+        batch_size=10,
     ),
-    "dashscope/text-embedding-v4": EmbedModelInfo(
-        model_id="dashscope/text-embedding-v4",
-        name="text-embedding-v4",
-        dimension=1024,
-        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings",
-        api_key="DASHSCOPE_API_KEY",
+    "cusc/qwen3-embedding-8b": EmbedModelInfo(
+        model_id="cusc/qwen3-embedding-8b",
+        name="qwen3-embedding-8b",
+        dimension=4096,
+        base_url="http://172.31.153.11:8080/cusc-ai/v1/embeddings",
+        api_key="CUSC_API_KEY",
+        batch_size=10,
     ),
 }
@@ -227,29 +193,9 @@ class RerankerInfo(BaseModel):
 # ============================================================

 DEFAULT_RERANKERS: dict[str, RerankerInfo] = {
-    "siliconflow/BAAI/bge-reranker-v2-m3": RerankerInfo(
-        name="BAAI/bge-reranker-v2-m3",
-        base_url="https://api.siliconflow.cn/v1/rerank",
-        api_key="SILICONFLOW_API_KEY",
-    ),
-    "siliconflow/Pro/BAAI/bge-reranker-v2-m3": RerankerInfo(
-        name="Pro/BAAI/bge-reranker-v2-m3",
-        base_url="https://api.siliconflow.cn/v1/rerank",
-        api_key="SILICONFLOW_API_KEY",
-    ),
-    "dashscope/gte-rerank-v2": RerankerInfo(
-        name="gte-rerank-v2",
-        base_url="https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
-        api_key="DASHSCOPE_API_KEY",
-    ),
-    "dashscope/qwen3-rerank": RerankerInfo(
-        name="qwen3-rerank",
-        base_url="https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
-        api_key="DASHSCOPE_API_KEY",
-    ),
-    "vllm/BAAI/bge-reranker-v2-m3": RerankerInfo(
-        name="BAAI/bge-reranker-v2-m3",
-        base_url="http://localhost:8000/v1/rerank",
-        api_key="no_api_key",
+    "cusc/qwen3-reranker-4b": RerankerInfo(
+        name="cusc/qwen3-reranker-4b",
+        base_url="http://172.31.153.11:8080/cusc-ai/v1/rerank",
+        api_key="CUSC_API_KEY",
     ),
 }
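For reference, each entry keys a provider-prefixed model_id to its endpoint, and the api_key field appears to name an environment variable rather than hold a secret. A minimal sketch of resolving one of the new cusc entries; get_embed_settings is a hypothetical helper, not part of this PR:

import os

from yuxi.config.static.models import DEFAULT_EMBED_MODELS


def get_embed_settings(model_id: str) -> tuple[str, str, int]:
    """Hypothetical lookup: return (base_url, api_key, dimension) for a registry entry."""
    info = DEFAULT_EMBED_MODELS[model_id]
    # "CUSC_API_KEY" is an environment-variable name here, not the key itself.
    key = os.environ.get(info.api_key, "")
    return info.base_url, key, info.dimension


base_url, key, dim = get_embed_settings("cusc/qwen3-embedding-8b")  # dim == 4096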
27 changes: 27 additions & 0 deletions backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py
@@ -55,6 +55,33 @@ def count_tokens(text: str) -> int:
    return max(1, len(parts)) if text.strip() else 0


def hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
    """Hard-split text at the token limit; a safety net applied after naive_merge."""
    token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[一-鿿]", text or ""))
    if not token_iter:
        cleaned = (text or "").strip()
        return [cleaned] if cleaned else []

    chunks: list[str] = []
    start = 0
    index = 0
    max_tokens = max(int(chunk_token_num or 0), 1)

    while index < len(token_iter):
        end_index = min(index + max_tokens, len(token_iter)) - 1
        end = token_iter[end_index].end()
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        start = end
        index = end_index + 1

    tail = text[start:].strip()
    if tail:
        chunks.append(tail)
    return chunks


def random_choices(arr: list[str], k: int) -> list[str]:
    if not arr:
        return []
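As a quick check on the new helper's behavior, here is a usage sketch, assuming the module path matches the file location above; under the regex, a run of ASCII word characters or a single CJK character each counts as one token:

from yuxi.knowledge.chunking.ragflow_like import nlp

text = "alpha beta gamma delta"  # four tokens under the regex above
pieces = nlp.hard_split_by_token_limit(text, chunk_token_num=2)
print(pieces)  # expected: ['alpha beta', 'gamma delta']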
@@ -30,6 +30,24 @@ def _iter_sections(markdown_content: str, delimiter: str) -> list[tuple[str, str
    return sections


def _ensure_chunk_token_limit(chunks: list[str], chunk_token_num: int) -> list[str]:
    """Enforce the token limit on output chunks: hard-split any that exceed it."""
    max_tokens = int(chunk_token_num or 0)
    if max_tokens <= 0:
        return [c.strip() for c in chunks if c and c.strip()]

    protected: list[str] = []
    for chunk in chunks:
        cleaned = (chunk or "").strip()
        if not cleaned:
            continue
        if nlp.count_tokens(cleaned) <= max_tokens:
            protected.append(cleaned)
        else:
            protected.extend(nlp.hard_split_by_token_limit(cleaned, max_tokens))
    return protected


def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None = None) -> list[str]:
    parser_config = parser_config or {}

@@ -38,9 +56,10 @@ def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None =
     overlapped_percent = int(parser_config.get("overlapped_percent", 0) or 0)

     sections = _iter_sections(markdown_content, delimiter)
-    return nlp.naive_merge(
+    chunks = nlp.naive_merge(
         sections,
         chunk_token_num=chunk_token_num,
         delimiter=delimiter,
         overlapped_percent=overlapped_percent,
     )
+    return _ensure_chunk_token_limit(chunks, chunk_token_num)
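Net effect: chunk_markdown now passes naive_merge output through the token-limit guard before returning. A usage sketch; the parser_config keys follow the reads above, while the module name md_chunker and the "\n\n" delimiter value are assumptions for illustration:

from yuxi.knowledge.chunking.ragflow_like import nlp
from yuxi.knowledge.chunking.ragflow_like.md_chunker import chunk_markdown  # assumed module name

chunks = chunk_markdown(
    "# Title\n\n" + "word " * 2000,
    parser_config={
        "chunk_token_num": 512,   # budget enforced by _ensure_chunk_token_limit
        "overlapped_percent": 0,
        "delimiter": "\n\n",      # assumed value
    },
)
# Every chunk should now respect the budget, even when naive_merge overshoots.
print(max(nlp.count_tokens(c) for c in chunks))  # expected to be <= 512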
@@ -84,32 +84,6 @@ def _docx_heading_tree(markdown_content: str) -> list[str]:
     return [element for element in root.get_tree() if element]


-def _hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
-    token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text or ""))
-    if not token_iter:
-        cleaned = (text or "").strip()
-        return [cleaned] if cleaned else []
-
-    chunks: list[str] = []
-    start = 0
-    index = 0
-    max_tokens = max(int(chunk_token_num or 0), 1)
-
-    while index < len(token_iter):
-        end_index = min(index + max_tokens, len(token_iter)) - 1
-        end = token_iter[end_index].end()
-        piece = text[start:end].strip()
-        if piece:
-            chunks.append(piece)
-        start = end
-        index = end_index + 1
-
-    tail = text[start:].strip()
-    if tail:
-        chunks.append(tail)
-    return chunks

 def _ensure_chunk_token_limit(
     chunks: list[str], chunk_token_num: int, delimiter: str, overlapped_percent: int
 ) -> list[str]:
@@ -161,7 +135,7 @@ def _ensure_chunk_token_limit(
         if nlp.count_tokens(text) <= max_tokens:
             protected.append(text)
         else:
-            protected.extend(_hard_split_by_token_limit(text, max_tokens))
+            protected.extend(nlp.hard_split_by_token_limit(text, max_tokens))

     return [chunk for chunk in protected if chunk.strip()]

110 changes: 110 additions & 0 deletions backend/package/yuxi/knowledge/implementations/lightrag.py
@@ -19,6 +19,76 @@
from yuxi.utils import hashstr, logger
from yuxi.utils.datetime_utils import utc_isoformat

# Patch LightRAG's tiktoken usage to avoid downloading from the internet at runtime.
# LightRAG calls _get_tiktoken_encoding_for_model() inside openai_embed/openai_complete_if_cache,
# which tries to download tiktoken BPE files from openaipublic.blob.core.windows.net.
# In air-gapped environments this causes ConnectionRefused errors.
import lightrag.llm.openai as _lightrag_openai


class _TiktokenWrapper:
    """Wraps a pre-loaded tiktoken encoding to bypass LightRAG's TiktokenTokenizer auto-download."""

    def __init__(self, model_name: str, encoding):
        self.model_name = model_name
        self._enc = encoding

    def encode(self, text: str) -> list[int]:
        return self._enc.encode(text)

    def decode(self, tokens: list[int]) -> str:
        return self._enc.decode(tokens)


class _CharTokenizer:
    """Fallback tokenizer that approximates 1 token ≈ 4 characters (GPT-style ratio)."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    def encode(self, text: str) -> list[int]:
        return list(range(0, len(text), 4))

    def decode(self, tokens: list[int]) -> str:
        return ""


try:
    import tiktoken

    _PATCHED_ENCODING = tiktoken.get_encoding("o200k_base")

    def _patched_get_encoding(model: str):
        if model not in _lightrag_openai._TIKTOKEN_ENCODING_CACHE:
            _lightrag_openai._TIKTOKEN_ENCODING_CACHE[model] = _PATCHED_ENCODING
        return _lightrag_openai._TIKTOKEN_ENCODING_CACHE[model]

    _lightrag_openai._get_tiktoken_encoding_for_model = _patched_get_encoding
    logger.info("Patched LightRAG tiktoken to use offline o200k_base encoding")
except Exception as _e:
    logger.warning(f"tiktoken unavailable ({_e}), using character-based fallback for LightRAG")

    _DUMMY = _CharTokenizer("fallback")

    def _fallback_get_encoding(model: str):
        if model not in _lightrag_openai._TIKTOKEN_ENCODING_CACHE:
            _lightrag_openai._TIKTOKEN_ENCODING_CACHE[model] = _DUMMY
        return _lightrag_openai._TIKTOKEN_ENCODING_CACHE[model]

    _lightrag_openai._get_tiktoken_encoding_for_model = _fallback_get_encoding
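To make the fallback's one-token-per-four-characters approximation concrete, a small sketch of how the patched encoder behaves when tiktoken is missing:

tok = _CharTokenizer("fallback")
print(len(tok.encode("a" * 10)))  # 3 -- positions [0, 4, 8], i.e. ceil(10 / 4)
print(tok.decode([0, 4, 8]))      # "" -- decoding is intentionally lossy; only token counts matter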


def _extract_json_from_text(text: str) -> str:
    """Extract a JSON object from LLM output that may contain reasoning/thinking text."""
    import re

    if not text:
        return text
    # Find the first {...} object; the regex tolerates one level of nested braces.
    match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
    return match.group(0) if match else text
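For example, a reasoning model's keyword-extraction reply may wrap the JSON in thinking text; the sample string below is illustrative, not captured from a real model:

raw = 'Okay, extracting keywords now. {"high_level_keywords": ["RAG"], "low_level_keywords": ["chunk", "token"]} Hope that helps.'
print(_extract_json_from_text(raw))
# {"high_level_keywords": ["RAG"], "low_level_keywords": ["chunk", "token"]}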


class LightRagKB(KnowledgeBase):
    """Knowledge base implementation backed by LightRAG."""
@@ -157,12 +227,23 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> LightRAG:
        working_dir = os.path.join(self.work_dir, db_id)
        os.makedirs(working_dir, exist_ok=True)

        # Build a tokenizer that works without internet access, so tiktoken never downloads o200k_base.tiktoken
        try:
            import tiktoken

            enc = tiktoken.get_encoding("o200k_base")
            tokenizer = _TiktokenWrapper("gpt-4o-mini", enc)
        except Exception:
            logger.warning("tiktoken unavailable, falling back to character-based tokenizer")
            tokenizer = _CharTokenizer("fallback")

        # Create the LightRAG instance
        rag = LightRAG(
            working_dir=working_dir,
            workspace=db_id,
            llm_model_func=self._get_llm_func(llm_info),
            embedding_func=self._get_embedding_func(embed_info),
            tokenizer=tokenizer,
            vector_storage="MilvusVectorDBStorage",
            kv_storage="JsonKVStorage",
            graph_storage="Neo4JStorage",
@@ -247,7 +328,36 @@ def _get_llm_func(self, llm_info: dict):

        model = select_model(model_spec=model_spec)

        # Reasoning models (MiniMax-M2.x, DeepSeek-R, etc.) wrap JSON in thinking
        # text, which breaks LightRAG's keyword_extraction mode that relies on
        # chat.completions.parse() + response_format. When keyword_extraction is
        # requested we call the model normally and extract JSON from the text.
        _REASONING_MODELS = {
            "minimax-m2.7",
            "minimax-m2.5",
            "deepseek-reasoner",
            "kimi-k2-thinking",
        }

        async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
            is_reasoning = model.model_name.lower() in _REASONING_MODELS
            keyword_extraction = kwargs.pop("keyword_extraction", False)

            if keyword_extraction and is_reasoning:
                text = await openai_complete_if_cache(
                    model=model.model_name,
                    prompt=prompt,
                    system_prompt=system_prompt,
                    history_messages=history_messages,
                    api_key=model.api_key,
                    base_url=model.base_url,
                    **kwargs,
                )
                return _extract_json_from_text(text)

            if keyword_extraction:
                kwargs["keyword_extraction"] = True

            return await openai_complete_if_cache(
                model=model.model_name,
                prompt=prompt,