Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c7e01c8
refactor(shared): change Remote*Request protocol to reference mode
sunnights Apr 26, 2026
d5f5457
feat(knowledge_runtime): add ConfigResolver for reference-mode config…
sunnights Apr 26, 2026
66dcdd9
fix(knowledge_runtime): align hybrid weight handling with Backend beh…
sunnights Apr 26, 2026
6a0a92b
feat(knowledge_runtime): update executors to use ConfigResolver
sunnights Apr 26, 2026
b9d24c2
feat(knowledge_runtime): add database access layer for reference mode
sunnights Apr 26, 2026
d0561b4
refactor(backend): simplify RemoteRagGateway to reference-mode protocol
sunnights Apr 26, 2026
86eedd2
refactor(knowledge_runtime): switch to FastAPI dependency injection f…
sunnights Apr 26, 2026
bbb0a88
refactor(knowledge_runtime): use resolve_admin_config for all admin o…
sunnights Apr 26, 2026
a423715
refactor: extract placeholder utilities to shared/utils/placeholder.py
sunnights Apr 26, 2026
91beb64
refactor(backend): restore ConnectionTestRuntimeSpec and simplify tes…
sunnights Apr 26, 2026
ced8782
refactor(shared): extract runtime config models to runtime_config.py
sunnights Apr 26, 2026
5399bc5
fix(backend): adapt internal/rag.py purge/drop endpoints for referenc…
sunnights Apr 26, 2026
cb44c62
refactor: merge two KR reference-mode branches based on design decisions
sunnights Apr 26, 2026
2d8e3f4
fix(knowledge_runtime): fix broken sentinel pattern in test _make_doc…
sunnights Apr 27, 2026
1a4163e
fix: update test payloads for reference-mode protocol and fix sentinel
sunnights Apr 27, 2026
954d87f
refactor(knowledge_runtime): split test_config_resolver.py into small…
sunnights Apr 27, 2026
85d8791
fix(docker): 修复 knowledge_runtime 数据库连接配置
sunnights Apr 28, 2026
ebd12ba
Merge remote-tracking branch 'wecode-ai/main' into zy_kb2
sunnights Apr 28, 2026
67f4b5c
refactor(rag): 删除 build_internal_* 死代码并统一使用 public 方法
sunnights Apr 28, 2026
c18d116
fix(uv): 修正 requires-python 版本为 3.10
sunnights Apr 28, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 7 additions & 28 deletions backend/app/api/endpoints/adapter/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
from typing import Optional

Expand All @@ -14,17 +15,12 @@
from app.models.user import User
from app.schemas.kind import Retriever
from app.services.adapters.retriever_kinds import retriever_kinds_service
from app.services.rag.gateway_factory import get_query_gateway
from app.services.rag.local_gateway import LocalRagGateway
from app.services.rag.remote_gateway import RemoteRagGatewayError
from app.services.rag.runtime_specs import ConnectionTestRuntimeSpec
from knowledge_engine.storage.factory import (
create_storage_backend_from_config,
get_all_storage_retrieval_methods,
get_supported_retrieval_methods,
get_supported_storage_types,
)
from shared.models import RuntimeRetrieverConfig

# RAG module is heavy (llama_index, scipy, pandas, grpc) - skip in standalone mode

Expand Down Expand Up @@ -259,7 +255,7 @@ async def test_retriever_connection(
}

try:
create_storage_backend_from_config(
storage_backend = create_storage_backend_from_config(
storage_type=storage_type,
url=url,
username=username,
Expand All @@ -268,30 +264,13 @@ async def test_retriever_connection(
index_strategy={"mode": "per_dataset"},
ext={},
)
runtime_spec = ConnectionTestRuntimeSpec(
retriever_config=RuntimeRetrieverConfig(
name="connection-test",
namespace="default",
storage_config={
"type": storage_type,
"url": url,
"username": username,
"password": password,
"apiKey": api_key,
"indexStrategy": {"mode": "per_dataset"},
"ext": {},
},
)
)
gateway = get_query_gateway()
try:
return await gateway.test_connection(runtime_spec)
except RemoteRagGatewayError:
return await LocalRagGateway().test_connection(runtime_spec)

success = await asyncio.to_thread(storage_backend.test_connection)
return {
"success": success,
"message": "Connection successful" if success else "Connection failed",
}
except ValueError as e:
return {"success": False, "message": str(e)}

except Exception as e:
logger.error(f"Retriever connection test failed: {str(e)}")
return {"success": False, "message": f"Connection failed: {str(e)}"}
19 changes: 12 additions & 7 deletions backend/app/api/endpoints/internal/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sqlalchemy.orm import Session

from app.api.dependencies import get_db
from app.services.auth.internal_service_token import verify_internal_service_token
from app.services.knowledge.protected_mediation import (
ProtectedKnowledgeMediationResponse,
protected_knowledge_mediator,
Expand Down Expand Up @@ -47,7 +48,11 @@

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/rag", tags=["internal-rag"])
router = APIRouter(
prefix="/rag",
tags=["internal-rag"],
dependencies=[Depends(verify_internal_service_token)],
)
runtime_resolver = RagRuntimeResolver()


Expand Down Expand Up @@ -700,11 +705,11 @@ async def purge_knowledge_index(
):
"""Delete all indexed chunks for one knowledge base from the local runtime."""
try:
runtime_spec = runtime_resolver.build_internal_purge_index_runtime_spec(
runtime_spec = runtime_resolver.build_public_purge_index_runtime_spec(
db=db,
knowledge_base_id=request.knowledge_base_id,
index_owner_user_id=request.index_owner_user_id,
retriever_config=request.retriever_config.model_dump(mode="python"),
user_id=request.user_id,
user_name=None,
)
return await LocalRagGateway().purge_knowledge_index(runtime_spec, db=db)
except ValueError as e:
Expand All @@ -726,11 +731,11 @@ async def drop_knowledge_index(
):
"""Physically drop the dedicated index/collection for one knowledge base."""
try:
runtime_spec = runtime_resolver.build_internal_drop_index_runtime_spec(
runtime_spec = runtime_resolver.build_public_drop_index_runtime_spec(
db=db,
knowledge_base_id=request.knowledge_base_id,
index_owner_user_id=request.index_owner_user_id,
retriever_config=request.retriever_config.model_dump(mode="python"),
user_id=request.user_id,
user_name=None,
)
return await LocalRagGateway().drop_knowledge_index(runtime_spec, db=db)
except ValueError as e:
Expand Down
4 changes: 4 additions & 0 deletions backend/app/services/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

"""Authentication services."""

from app.services.auth.internal_service_token import (
verify_internal_service_token,
)
from app.services.auth.rag_download_token import (
RagDownloadTokenInfo,
create_rag_download_token,
Expand All @@ -24,6 +27,7 @@
)

__all__ = [
"verify_internal_service_token",
"TaskTokenData",
"TaskTokenInfo",
"create_task_token",
Expand Down
62 changes: 62 additions & 0 deletions backend/app/services/auth/internal_service_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: 2026 Weibo, Inc.
#
# SPDX-License-Identifier: Apache-2.0

"""Internal service token verification for service-to-service endpoints.

Provides FastAPI dependency that validates INTERNAL_SERVICE_TOKEN in
Authorization headers. Used to secure /internal/* endpoints so only
trusted services (chat_shell, knowledge_runtime) can call them.
"""

from __future__ import annotations

import hmac
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.core.config import settings

# HTTPBearer security scheme for OpenAPI documentation
security = HTTPBearer(auto_error=False)


def verify_internal_service_token(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> None:
"""Verify internal service authentication token.

This dependency checks for a valid Bearer token in the Authorization header.
If INTERNAL_SERVICE_TOKEN is not configured (empty string), authentication is
skipped (dev mode).

Args:
credentials: The Bearer token credentials from the Authorization header.

Raises:
HTTPException: 401 Unauthorized if token is missing or invalid.
"""
expected_token = settings.INTERNAL_SERVICE_TOKEN

# Skip authentication if token is not configured (development mode)
if not expected_token:
return

# Check if credentials are provided
if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication token",
headers={"WWW-Authenticate": "Bearer"},
)

# Use constant-time comparison to prevent timing attacks
provided_token = credentials.credentials
if not hmac.compare_digest(provided_token, expected_token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token",
headers={"WWW-Authenticate": "Bearer"},
)
34 changes: 2 additions & 32 deletions backend/app/services/rag/embedding/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,16 @@
from sqlalchemy.orm import Session

from app.models.kind import Kind
from app.services.chat.config.model_resolver import (
build_default_headers_with_placeholders,
)
from knowledge_engine.embedding.factory import (
create_embedding_model_from_runtime_config as engine_create_embedding_model_from_runtime_config,
)
from shared.models import RuntimeEmbeddingModelConfig
from shared.utils.crypto import decrypt_api_key
from shared.utils.placeholder import process_custom_headers_placeholders

logger = logging.getLogger(__name__)


def _process_custom_headers_placeholders(
custom_headers: Dict[str, Any], user_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Process placeholders in custom headers.

Supports placeholder format: ${user.name}

Args:
custom_headers: Custom headers dict (may contain placeholders)
user_name: User name for placeholder replacement

Returns:
Custom headers with placeholders replaced
"""
if not custom_headers or not isinstance(custom_headers, dict):
return custom_headers

# Build data sources for placeholder replacement
# Only support ${user.name} for now
data_sources: Dict[str, Dict[str, Any]] = {
"user": {"name": user_name or ""},
}

# Use existing build_default_headers_with_placeholders function
return build_default_headers_with_placeholders(custom_headers, data_sources)


def create_embedding_model_from_crd(
db: Session,
user_id: int,
Expand Down Expand Up @@ -165,7 +135,7 @@ def create_embedding_model_from_crd(

# Process placeholders in custom_headers (e.g., ${user.name})
if custom_headers and isinstance(custom_headers, dict):
custom_headers = _process_custom_headers_placeholders(custom_headers, user_name)
custom_headers = process_custom_headers_placeholders(custom_headers, user_name)
logger.info(
f"Processed custom_headers placeholders for embedding_model '{model_name}'"
)
Expand Down
49 changes: 8 additions & 41 deletions backend/app/services/rag/remote_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from app.services.context import context_service
from app.services.rag.content_refs import build_content_ref_for_attachment
from app.services.rag.runtime_specs import (
ConnectionTestRuntimeSpec,
DeleteRuntimeSpec,
DropKnowledgeIndexRuntimeSpec,
IndexRuntimeSpec,
Expand All @@ -33,7 +32,6 @@
RemoteQueryRequest,
RemoteQueryResponse,
RemoteRagError,
RemoteTestConnectionRequest,
)


Expand Down Expand Up @@ -148,27 +146,14 @@ async def index_document(
)
payload = RemoteIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
document_id=spec.document_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config
or {
"name": spec.retriever_name,
"namespace": spec.retriever_namespace,
},
embedding_model_config=spec.embedding_model_config
or {
"model_name": spec.embedding_model_name,
"model_namespace": spec.embedding_model_namespace,
},
splitter_config=spec.splitter_config,
index_families=spec.index_families,
source_file=source_file,
file_extension=file_extension,
content_ref=build_content_ref_for_attachment(
db=db,
attachment_id=spec.source.attachment_id,
),
source_file=source_file,
file_extension=file_extension,
user_name=spec.user_name,
)
return await self._post_model("/internal/rag/index", payload)

Expand All @@ -181,14 +166,11 @@ async def query(
del db
payload = RemoteQueryRequest(
knowledge_base_ids=spec.knowledge_base_ids,
user_id=spec.user_id or 0,
query=spec.query,
max_results=spec.max_results,
document_ids=spec.document_ids,
metadata_condition=spec.metadata_condition,
user_name=spec.user_name,
knowledge_base_configs=spec.knowledge_base_configs,
enabled_index_families=spec.enabled_index_families,
retrieval_policy=spec.retrieval_policy,
)
response_payload = await self._post_model("/internal/rag/query", payload)
response = RemoteQueryResponse.model_validate(response_payload)
Expand All @@ -206,10 +188,8 @@ async def delete_document_index(
del db
payload = RemoteDeleteDocumentIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
document_ref=spec.document_ref,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
enabled_index_families=spec.enabled_index_families,
)
return await self._post_model("/internal/rag/delete-document-index", payload)

Expand All @@ -222,8 +202,7 @@ async def purge_knowledge_index(
del db
payload = RemotePurgeKnowledgeIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
user_id=spec.index_owner_user_id,
)
return await self._post_model("/internal/rag/purge-knowledge-index", payload)

Expand All @@ -236,8 +215,7 @@ async def drop_knowledge_index(
del db
payload = RemoteDropKnowledgeIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
user_id=spec.index_owner_user_id,
)
return await self._post_model("/internal/rag/drop-knowledge-index", payload)

Expand All @@ -250,8 +228,7 @@ async def list_chunks(
del db
payload = RemoteListChunksRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
user_id=spec.index_owner_user_id,
max_chunks=spec.max_chunks,
query=spec.query,
metadata_condition=spec.metadata_condition,
Expand All @@ -260,16 +237,6 @@ async def list_chunks(
response = RemoteListChunksResponse.model_validate(response_payload)
return response.model_dump()

async def test_connection(
self,
spec: ConnectionTestRuntimeSpec,
*,
db: Session | None = None,
) -> dict[str, Any]:
del db
payload = RemoteTestConnectionRequest(retriever_config=spec.retriever_config)
return await self._post_model("/internal/rag/test-connection", payload)


def _get_attachment_source_metadata(
*,
Expand Down
Loading
Loading