diff --git a/backend/app/api/endpoints/adapter/retrievers.py b/backend/app/api/endpoints/adapter/retrievers.py index 3aa384e7e..9790fb588 100644 --- a/backend/app/api/endpoints/adapter/retrievers.py +++ b/backend/app/api/endpoints/adapter/retrievers.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import asyncio import logging from typing import Optional @@ -14,17 +15,12 @@ from app.models.user import User from app.schemas.kind import Retriever from app.services.adapters.retriever_kinds import retriever_kinds_service -from app.services.rag.gateway_factory import get_query_gateway -from app.services.rag.local_gateway import LocalRagGateway -from app.services.rag.remote_gateway import RemoteRagGatewayError -from app.services.rag.runtime_specs import ConnectionTestRuntimeSpec from knowledge_engine.storage.factory import ( create_storage_backend_from_config, get_all_storage_retrieval_methods, get_supported_retrieval_methods, get_supported_storage_types, ) -from shared.models import RuntimeRetrieverConfig # RAG module is heavy (llama_index, scipy, pandas, grpc) - skip in standalone mode @@ -259,7 +255,7 @@ async def test_retriever_connection( } try: - create_storage_backend_from_config( + storage_backend = create_storage_backend_from_config( storage_type=storage_type, url=url, username=username, @@ -268,30 +264,13 @@ async def test_retriever_connection( index_strategy={"mode": "per_dataset"}, ext={}, ) - runtime_spec = ConnectionTestRuntimeSpec( - retriever_config=RuntimeRetrieverConfig( - name="connection-test", - namespace="default", - storage_config={ - "type": storage_type, - "url": url, - "username": username, - "password": password, - "apiKey": api_key, - "indexStrategy": {"mode": "per_dataset"}, - "ext": {}, - }, - ) - ) - gateway = get_query_gateway() - try: - return await gateway.test_connection(runtime_spec) - except RemoteRagGatewayError: - return await LocalRagGateway().test_connection(runtime_spec) - + success = await asyncio.to_thread(storage_backend.test_connection) + return { + "success": success, + "message": "Connection successful" if success else "Connection failed", + } except ValueError as e: return {"success": False, "message": str(e)} - except Exception as e: logger.error(f"Retriever connection test failed: {str(e)}") return {"success": False, "message": f"Connection failed: {str(e)}"} diff --git a/backend/app/api/endpoints/internal/rag.py b/backend/app/api/endpoints/internal/rag.py index 04be7244e..119450319 100644 --- a/backend/app/api/endpoints/internal/rag.py +++ b/backend/app/api/endpoints/internal/rag.py @@ -17,6 +17,7 @@ from sqlalchemy.orm import Session from app.api.dependencies import get_db +from app.services.auth.internal_service_token import verify_internal_service_token from app.services.knowledge.protected_mediation import ( ProtectedKnowledgeMediationResponse, protected_knowledge_mediator, @@ -47,7 +48,11 @@ logger = logging.getLogger(__name__) -router = APIRouter(prefix="/rag", tags=["internal-rag"]) +router = APIRouter( + prefix="/rag", + tags=["internal-rag"], + dependencies=[Depends(verify_internal_service_token)], +) runtime_resolver = RagRuntimeResolver() @@ -647,9 +652,11 @@ async def get_all_chunks( All chunks from the knowledge base """ try: - runtime_spec = runtime_resolver.build_internal_list_chunks_runtime_spec( + runtime_spec = runtime_resolver.build_public_list_chunks_runtime_spec( db=db, knowledge_base_id=request.knowledge_base_id, + user_id=request.user_id, + user_name=None, max_chunks=request.max_chunks, query=request.query, metadata_condition=request.metadata_condition, @@ -700,11 +707,11 @@ async def purge_knowledge_index( ): """Delete all indexed chunks for one knowledge base from the local runtime.""" try: - runtime_spec = runtime_resolver.build_internal_purge_index_runtime_spec( + runtime_spec = runtime_resolver.build_public_purge_index_runtime_spec( db=db, knowledge_base_id=request.knowledge_base_id, - index_owner_user_id=request.index_owner_user_id, - retriever_config=request.retriever_config.model_dump(mode="python"), + user_id=request.user_id, + user_name=None, ) return await LocalRagGateway().purge_knowledge_index(runtime_spec, db=db) except ValueError as e: @@ -726,11 +733,11 @@ async def drop_knowledge_index( ): """Physically drop the dedicated index/collection for one knowledge base.""" try: - runtime_spec = runtime_resolver.build_internal_drop_index_runtime_spec( + runtime_spec = runtime_resolver.build_public_drop_index_runtime_spec( db=db, knowledge_base_id=request.knowledge_base_id, - index_owner_user_id=request.index_owner_user_id, - retriever_config=request.retriever_config.model_dump(mode="python"), + user_id=request.user_id, + user_name=None, ) return await LocalRagGateway().drop_knowledge_index(runtime_spec, db=db) except ValueError as e: diff --git a/backend/app/services/auth/__init__.py b/backend/app/services/auth/__init__.py index 456c2f796..5b6f8f46b 100644 --- a/backend/app/services/auth/__init__.py +++ b/backend/app/services/auth/__init__.py @@ -4,6 +4,9 @@ """Authentication services.""" +from app.services.auth.internal_service_token import ( + verify_internal_service_token, +) from app.services.auth.rag_download_token import ( RagDownloadTokenInfo, create_rag_download_token, @@ -24,6 +27,7 @@ ) __all__ = [ + "verify_internal_service_token", "TaskTokenData", "TaskTokenInfo", "create_task_token", diff --git a/backend/app/services/auth/internal_service_token.py b/backend/app/services/auth/internal_service_token.py new file mode 100644 index 000000000..1792c29d1 --- /dev/null +++ b/backend/app/services/auth/internal_service_token.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Internal service token verification for service-to-service endpoints. + +Provides FastAPI dependency that validates INTERNAL_SERVICE_TOKEN in +Authorization headers. Used to secure /internal/* endpoints so only +trusted services (chat_shell, knowledge_runtime) can call them. +""" + +from __future__ import annotations + +import hmac +from typing import Optional + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.core.config import settings + +# HTTPBearer security scheme for OpenAPI documentation +security = HTTPBearer(auto_error=False) + + +def verify_internal_service_token( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), +) -> None: + """Verify internal service authentication token. + + This dependency checks for a valid Bearer token in the Authorization header. + If INTERNAL_SERVICE_TOKEN is not configured (empty string), authentication is + skipped (dev mode). + + Args: + credentials: The Bearer token credentials from the Authorization header. + + Raises: + HTTPException: 401 Unauthorized if token is missing or invalid. + """ + expected_token = settings.INTERNAL_SERVICE_TOKEN + + # Skip authentication if token is not configured (development mode) + if not expected_token: + return + + # Check if credentials are provided + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Use constant-time comparison to prevent timing attacks + provided_token = credentials.credentials + if not hmac.compare_digest(provided_token, expected_token): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/backend/app/services/rag/embedding/factory.py b/backend/app/services/rag/embedding/factory.py index ce1244c60..c78cf2a75 100644 --- a/backend/app/services/rag/embedding/factory.py +++ b/backend/app/services/rag/embedding/factory.py @@ -25,38 +25,11 @@ ) from shared.models import RuntimeEmbeddingModelConfig from shared.utils.crypto import decrypt_api_key +from shared.utils.placeholder import process_custom_headers_placeholders logger = logging.getLogger(__name__) -def _process_custom_headers_placeholders( - custom_headers: Dict[str, Any], user_name: Optional[str] = None -) -> Dict[str, Any]: - """ - Process placeholders in custom headers. - - Supports placeholder format: ${user.name} - - Args: - custom_headers: Custom headers dict (may contain placeholders) - user_name: User name for placeholder replacement - - Returns: - Custom headers with placeholders replaced - """ - if not custom_headers or not isinstance(custom_headers, dict): - return custom_headers - - # Build data sources for placeholder replacement - # Only support ${user.name} for now - data_sources: Dict[str, Dict[str, Any]] = { - "user": {"name": user_name or ""}, - } - - # Use existing build_default_headers_with_placeholders function - return build_default_headers_with_placeholders(custom_headers, data_sources) - - def create_embedding_model_from_crd( db: Session, user_id: int, @@ -169,7 +142,7 @@ def create_embedding_model_from_crd( # Process placeholders in custom_headers (e.g., ${user.name}) if custom_headers and isinstance(custom_headers, dict): - custom_headers = _process_custom_headers_placeholders(custom_headers, user_name) + custom_headers = process_custom_headers_placeholders(custom_headers, user_name) logger.info( f"Processed custom_headers placeholders for embedding_model '{model_name}'" ) diff --git a/backend/app/services/rag/remote_gateway.py b/backend/app/services/rag/remote_gateway.py index 8a940d6d5..a4a48d2ea 100644 --- a/backend/app/services/rag/remote_gateway.py +++ b/backend/app/services/rag/remote_gateway.py @@ -15,7 +15,6 @@ from app.services.context import context_service from app.services.rag.content_refs import build_content_ref_for_attachment from app.services.rag.runtime_specs import ( - ConnectionTestRuntimeSpec, DeleteRuntimeSpec, DropKnowledgeIndexRuntimeSpec, IndexRuntimeSpec, @@ -33,7 +32,6 @@ RemoteQueryRequest, RemoteQueryResponse, RemoteRagError, - RemoteTestConnectionRequest, ) @@ -148,27 +146,14 @@ async def index_document( ) payload = RemoteIndexRequest( knowledge_base_id=spec.knowledge_base_id, + user_id=spec.index_owner_user_id, document_id=spec.document_id, - index_owner_user_id=spec.index_owner_user_id, - retriever_config=spec.retriever_config - or { - "name": spec.retriever_name, - "namespace": spec.retriever_namespace, - }, - embedding_model_config=spec.embedding_model_config - or { - "model_name": spec.embedding_model_name, - "model_namespace": spec.embedding_model_namespace, - }, - splitter_config=spec.splitter_config, - index_families=spec.index_families, + source_file=source_file, + file_extension=file_extension, content_ref=build_content_ref_for_attachment( db=db, attachment_id=spec.source.attachment_id, ), - source_file=source_file, - file_extension=file_extension, - user_name=spec.user_name, ) return await self._post_model("/internal/rag/index", payload) @@ -181,14 +166,11 @@ async def query( del db payload = RemoteQueryRequest( knowledge_base_ids=spec.knowledge_base_ids, + user_id=spec.user_id or 0, query=spec.query, max_results=spec.max_results, document_ids=spec.document_ids, metadata_condition=spec.metadata_condition, - user_name=spec.user_name, - knowledge_base_configs=spec.knowledge_base_configs, - enabled_index_families=spec.enabled_index_families, - retrieval_policy=spec.retrieval_policy, ) response_payload = await self._post_model("/internal/rag/query", payload) response = RemoteQueryResponse.model_validate(response_payload) @@ -206,10 +188,8 @@ async def delete_document_index( del db payload = RemoteDeleteDocumentIndexRequest( knowledge_base_id=spec.knowledge_base_id, + user_id=spec.index_owner_user_id, document_ref=spec.document_ref, - index_owner_user_id=spec.index_owner_user_id, - retriever_config=spec.retriever_config, - enabled_index_families=spec.enabled_index_families, ) return await self._post_model("/internal/rag/delete-document-index", payload) @@ -222,8 +202,7 @@ async def purge_knowledge_index( del db payload = RemotePurgeKnowledgeIndexRequest( knowledge_base_id=spec.knowledge_base_id, - index_owner_user_id=spec.index_owner_user_id, - retriever_config=spec.retriever_config, + user_id=spec.index_owner_user_id, ) return await self._post_model("/internal/rag/purge-knowledge-index", payload) @@ -236,8 +215,7 @@ async def drop_knowledge_index( del db payload = RemoteDropKnowledgeIndexRequest( knowledge_base_id=spec.knowledge_base_id, - index_owner_user_id=spec.index_owner_user_id, - retriever_config=spec.retriever_config, + user_id=spec.index_owner_user_id, ) return await self._post_model("/internal/rag/drop-knowledge-index", payload) @@ -250,8 +228,7 @@ async def list_chunks( del db payload = RemoteListChunksRequest( knowledge_base_id=spec.knowledge_base_id, - index_owner_user_id=spec.index_owner_user_id, - retriever_config=spec.retriever_config, + user_id=spec.index_owner_user_id, max_chunks=spec.max_chunks, query=spec.query, metadata_condition=spec.metadata_condition, @@ -260,16 +237,6 @@ async def list_chunks( response = RemoteListChunksResponse.model_validate(response_payload) return response.model_dump() - async def test_connection( - self, - spec: ConnectionTestRuntimeSpec, - *, - db: Session | None = None, - ) -> dict[str, Any]: - del db - payload = RemoteTestConnectionRequest(retriever_config=spec.retriever_config) - return await self._post_model("/internal/rag/test-connection", payload) - def _get_attachment_source_metadata( *, diff --git a/backend/app/services/rag/retrieval_service.py b/backend/app/services/rag/retrieval_service.py index 4a939cca6..ab4d378cf 100644 --- a/backend/app/services/rag/retrieval_service.py +++ b/backend/app/services/rag/retrieval_service.py @@ -938,11 +938,12 @@ async def get_all_chunks_from_knowledge_base( self, knowledge_base_id: int, db: Session, + user_id: int, max_chunks: int = 10000, query: Optional[str] = None, metadata_condition: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: - """Get all chunks from a knowledge base without permission check. + """Get all chunks from a knowledge base with permission check. This method is used for smart context injection where we need all chunks from a knowledge base to determine if direct injection is possible. @@ -953,6 +954,7 @@ async def get_all_chunks_from_knowledge_base( Args: knowledge_base_id: Knowledge base ID db: Database session + user_id: User ID for permission check max_chunks: Maximum number of chunks to retrieve (safety limit) query: Optional query string for logging purposes metadata_condition: Optional metadata filter conditions @@ -961,16 +963,18 @@ async def get_all_chunks_from_knowledge_base( List of chunk dicts with content, title, chunk_id, doc_ref, metadata Raises: - ValueError: If knowledge base not found or configuration invalid + ValueError: If knowledge base not found, access denied, or configuration invalid """ from app.services.rag.gateway_factory import get_list_chunks_gateway from app.services.rag.runtime_resolver import RagRuntimeResolver # Build runtime spec via resolver runtime_resolver = RagRuntimeResolver() - spec = runtime_resolver.build_internal_list_chunks_runtime_spec( + spec = runtime_resolver.build_public_list_chunks_runtime_spec( db=db, knowledge_base_id=knowledge_base_id, + user_id=user_id, + user_name=None, max_chunks=max_chunks, query=query, metadata_condition=metadata_condition, diff --git a/backend/app/services/rag/runtime_resolver.py b/backend/app/services/rag/runtime_resolver.py index 719977ee6..82f051c5c 100644 --- a/backend/app/services/rag/runtime_resolver.py +++ b/backend/app/services/rag/runtime_resolver.py @@ -11,7 +11,6 @@ get_kb_index_info, get_kb_index_info_by_record, ) -from app.services.rag.embedding.factory import _process_custom_headers_placeholders from app.services.rag.runtime_specs import ( ConnectionTestRuntimeSpec, DeleteRuntimeSpec, @@ -31,6 +30,7 @@ normalize_additional_input_modalities, ) from shared.utils.crypto import decrypt_api_key +from shared.utils.placeholder import process_custom_headers_placeholders class RagRuntimeResolver: @@ -268,28 +268,6 @@ def build_public_list_chunks_runtime_spec( metadata_condition=metadata_condition, ) - def build_internal_list_chunks_runtime_spec( - self, - *, - db: Session, - knowledge_base_id: int, - max_chunks: int, - query: str | None = None, - metadata_condition: dict | None = None, - ) -> ListChunksRuntimeSpec: - kb = self._get_knowledge_base_record(db=db, knowledge_base_id=knowledge_base_id) - if kb is None: - raise ValueError(f"Knowledge base {knowledge_base_id} not found") - - return self._build_list_chunks_runtime_spec( - db=db, - kb=kb, - index_owner_user_id=kb.user_id, - max_chunks=max_chunks, - query=query, - metadata_condition=metadata_condition, - ) - def _build_list_chunks_runtime_spec( self, *, @@ -418,42 +396,6 @@ def build_public_drop_index_runtime_spec( spec_type="drop", ) - def build_internal_purge_index_runtime_spec( - self, - *, - db: Session, - knowledge_base_id: int, - index_owner_user_id: int, - retriever_config: RuntimeRetrieverConfig | dict, - ) -> PurgeKnowledgeRuntimeSpec: - kb = self._get_knowledge_base_record(db=db, knowledge_base_id=knowledge_base_id) - if kb is None: - raise ValueError(f"Knowledge base {knowledge_base_id} not found") - - return PurgeKnowledgeRuntimeSpec( - knowledge_base_id=knowledge_base_id, - index_owner_user_id=index_owner_user_id, - retriever_config=retriever_config, - ) - - def build_internal_drop_index_runtime_spec( - self, - *, - db: Session, - knowledge_base_id: int, - index_owner_user_id: int, - retriever_config: RuntimeRetrieverConfig | dict, - ) -> DropKnowledgeIndexRuntimeSpec: - kb = self._get_knowledge_base_record(db=db, knowledge_base_id=knowledge_base_id) - if kb is None: - raise ValueError(f"Knowledge base {knowledge_base_id} not found") - - return DropKnowledgeIndexRuntimeSpec( - knowledge_base_id=knowledge_base_id, - index_owner_user_id=index_owner_user_id, - retriever_config=retriever_config, - ) - def _build_query_knowledge_base_configs( self, *, @@ -658,7 +600,7 @@ def _build_resolved_embedding_model_config( protocol = spec.get("protocol") or env.get("model") custom_headers = env.get("custom_headers", {}) if custom_headers and isinstance(custom_headers, dict): - custom_headers = _process_custom_headers_placeholders( + custom_headers = process_custom_headers_placeholders( custom_headers, user_name, ) diff --git a/backend/tests/api/endpoints/internal/test_rag_retrieve_endpoint.py b/backend/tests/api/endpoints/internal/test_rag_retrieve_endpoint.py index 4183722ce..feedd4121 100644 --- a/backend/tests/api/endpoints/internal/test_rag_retrieve_endpoint.py +++ b/backend/tests/api/endpoints/internal/test_rag_retrieve_endpoint.py @@ -117,15 +117,7 @@ def test_internal_retrieve_returns_restricted_safe_summary(test_client): def test_internal_all_chunks_routes_protocol_request_through_local_gateway(test_client): payload = { "knowledge_base_id": 7, - "index_owner_user_id": 9, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, + "user_id": 9, "max_chunks": 1000, "query": "list_index_chunks", "metadata_condition": { @@ -138,7 +130,7 @@ def test_internal_all_chunks_routes_protocol_request_through_local_gateway(test_ runtime_spec = object() with ( patch( - "app.api.endpoints.internal.rag.runtime_resolver.build_internal_list_chunks_runtime_spec", + "app.api.endpoints.internal.rag.runtime_resolver.build_public_list_chunks_runtime_spec", return_value=runtime_spec, ) as mock_build_spec, patch( @@ -176,6 +168,8 @@ def test_internal_all_chunks_routes_protocol_request_through_local_gateway(test_ mock_build_spec.assert_called_once_with( db=ANY, knowledge_base_id=7, + user_id=9, + user_name=None, max_chunks=1000, query="list_index_chunks", metadata_condition=payload["metadata_condition"], @@ -188,20 +182,12 @@ def test_internal_purge_index_routes_protocol_request_through_local_gateway( ): payload = { "knowledge_base_id": 7, - "index_owner_user_id": 9, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, + "user_id": 9, } runtime_spec = object() with ( patch( - "app.api.endpoints.internal.rag.runtime_resolver.build_internal_purge_index_runtime_spec", + "app.api.endpoints.internal.rag.runtime_resolver.build_public_purge_index_runtime_spec", return_value=runtime_spec, ) as mock_build_spec, patch( @@ -228,8 +214,8 @@ def test_internal_purge_index_routes_protocol_request_through_local_gateway( mock_build_spec.assert_called_once_with( db=ANY, knowledge_base_id=7, - index_owner_user_id=9, - retriever_config=payload["retriever_config"], + user_id=9, + user_name=None, ) mock_purge.assert_awaited_once_with(runtime_spec, db=ANY) @@ -239,21 +225,12 @@ def test_internal_drop_index_routes_protocol_request_through_local_gateway( ): payload = { "knowledge_base_id": 7, - "index_owner_user_id": 9, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - "indexStrategy": {"mode": "per_dataset"}, - }, - }, + "user_id": 9, } runtime_spec = object() with ( patch( - "app.api.endpoints.internal.rag.runtime_resolver.build_internal_drop_index_runtime_spec", + "app.api.endpoints.internal.rag.runtime_resolver.build_public_drop_index_runtime_spec", return_value=runtime_spec, ) as mock_build_spec, patch( @@ -280,8 +257,8 @@ def test_internal_drop_index_routes_protocol_request_through_local_gateway( mock_build_spec.assert_called_once_with( db=ANY, knowledge_base_id=7, - index_owner_user_id=9, - retriever_config=payload["retriever_config"], + user_id=9, + user_name=None, ) mock_drop.assert_awaited_once_with(runtime_spec, db=ANY) diff --git a/backend/tests/api/endpoints/test_retrievers_api.py b/backend/tests/api/endpoints/test_retrievers_api.py index 090637252..c9b299b95 100644 --- a/backend/tests/api/endpoints/test_retrievers_api.py +++ b/backend/tests/api/endpoints/test_retrievers_api.py @@ -4,27 +4,23 @@ from unittest.mock import AsyncMock, patch -from app.services.rag.runtime_specs import ConnectionTestRuntimeSpec - def _auth_header(token: str) -> dict[str, str]: return {"Authorization": f"Bearer {token}"} -def test_retriever_test_connection_uses_gateway_runtime_spec( +def test_retriever_test_connection_tests_storage_directly( test_client, test_token: str, ): - gateway = AsyncMock() - gateway.test_connection.return_value = { - "success": True, - "message": "Connection successful", - } + mock_backend = AsyncMock() + # test_connection is called via asyncio.to_thread, so it must be a sync function + mock_backend.test_connection = lambda: True with patch( - "app.api.endpoints.adapter.retrievers.get_query_gateway", - return_value=gateway, - ) as mock_get_gateway: + "app.api.endpoints.adapter.retrievers.create_storage_backend_from_config", + return_value=mock_backend, + ) as mock_create: response = test_client.post( "/api/retrievers/test-connection", headers=_auth_header(test_token), @@ -42,20 +38,15 @@ def test_retriever_test_connection_uses_gateway_runtime_spec( "success": True, "message": "Connection successful", } - mock_get_gateway.assert_called_once() - gateway.test_connection.assert_awaited_once() - - runtime_spec = gateway.test_connection.await_args.args[0] - assert isinstance(runtime_spec, ConnectionTestRuntimeSpec) - assert runtime_spec.retriever_config.storage_config == { - "type": "qdrant", - "url": "http://qdrant:6333", - "username": "alice", - "password": "secret", - "apiKey": "api-token", - "indexStrategy": {"mode": "per_dataset"}, - "ext": {}, - } + mock_create.assert_called_once_with( + storage_type="qdrant", + url="http://qdrant:6333", + username="alice", + password="secret", + api_key="api-token", + index_strategy={"mode": "per_dataset"}, + ext={}, + ) def test_retriever_test_connection_validates_required_fields( diff --git a/backend/tests/services/knowledge/test_orchestrator.py b/backend/tests/services/knowledge/test_orchestrator.py index a8802b1ea..e1a409775 100644 --- a/backend/tests/services/knowledge/test_orchestrator.py +++ b/backend/tests/services/knowledge/test_orchestrator.py @@ -1292,7 +1292,9 @@ def test_list_documents_populates_created_by_from_user_table(self) -> None: mock_db.query.return_value = user_query result = orchestrator.list_documents( - db=mock_db, user=user, knowledge_base_id=10, + db=mock_db, + user=user, + knowledge_base_id=10, ) assert isinstance(result, KnowledgeDocumentListResponse) @@ -1329,7 +1331,9 @@ def test_list_documents_created_by_none_when_user_deleted(self) -> None: mock_db.query.return_value = user_query result = orchestrator.list_documents( - db=mock_db, user=user, knowledge_base_id=10, + db=mock_db, + user=user, + knowledge_base_id=10, ) assert result.items[0].created_by is None @@ -1354,7 +1358,9 @@ def test_list_documents_empty_documents_no_user_query(self) -> None: mock_db = MagicMock() result = orchestrator.list_documents( - db=mock_db, user=user, knowledge_base_id=10, + db=mock_db, + user=user, + knowledge_base_id=10, ) # No db.query should be called since documents list is empty @@ -1391,7 +1397,9 @@ def test_list_documents_deduplicates_user_ids(self) -> None: mock_db.query.return_value = user_query result = orchestrator.list_documents( - db=mock_db, user=user, knowledge_base_id=10, + db=mock_db, + user=user, + knowledge_base_id=10, ) # Both documents should have the same created_by diff --git a/backend/tests/services/rag/test_local_gateway.py b/backend/tests/services/rag/test_local_gateway.py index 1ecee51e5..5bb064073 100644 --- a/backend/tests/services/rag/test_local_gateway.py +++ b/backend/tests/services/rag/test_local_gateway.py @@ -129,18 +129,19 @@ async def test_local_gateway_test_connection_delegates_to_connection_executor(): gateway._connection_test_executor = AsyncMock( return_value={"success": True, "message": "Connection successful"} ) + db = MagicMock() spec = ConnectionTestRuntimeSpec( retriever_config=RuntimeRetrieverConfig( - name="retriever-a", + name="test-retriever", namespace="default", - storage_config={"type": "qdrant"}, + storage_config={"type": "qdrant", "url": "http://localhost:6333"}, ) ) - result = await gateway.test_connection(spec) + result = await gateway.test_connection(spec, db=db) assert result == {"success": True, "message": "Connection successful"} - gateway._connection_test_executor.assert_awaited_once_with(spec, db=None) + gateway._connection_test_executor.assert_awaited_once_with(spec, db=db) @pytest.mark.asyncio diff --git a/backend/tests/services/rag/test_remote_gateway.py b/backend/tests/services/rag/test_remote_gateway.py index 10807cb28..8d9c3ae6c 100644 --- a/backend/tests/services/rag/test_remote_gateway.py +++ b/backend/tests/services/rag/test_remote_gateway.py @@ -18,20 +18,15 @@ from app.services.rag.local_gateway import LocalRagGateway from app.services.rag.remote_gateway import RemoteRagGateway, RemoteRagGatewayError from app.services.rag.runtime_specs import ( - ConnectionTestRuntimeSpec, DeleteRuntimeSpec, DropKnowledgeIndexRuntimeSpec, IndexRuntimeSpec, IndexSource, ListChunksRuntimeSpec, PurgeKnowledgeRuntimeSpec, - QueryKnowledgeBaseRuntimeConfig, QueryRuntimeSpec, - RuntimeEmbeddingModelConfig, - RuntimeRetrievalConfig, - RuntimeRetrieverConfig, ) -from shared.models import PresignedUrlContentRef +from shared.models import PresignedUrlContentRef, RuntimeRetrieverConfig def _build_response( @@ -45,7 +40,9 @@ def _build_response( @pytest.mark.asyncio -async def test_remote_gateway_index_document_posts_runtime_request(mocker) -> None: +async def test_remote_gateway_index_document_posts_reference_mode_request( + mocker, +) -> None: db = MagicMock() mocker.patch( "app.services.rag.remote_gateway.build_content_ref_for_attachment", @@ -79,27 +76,6 @@ async def test_remote_gateway_index_document_posts_runtime_request(mocker) -> No embedding_model_name="embedding-a", embedding_model_namespace="default", source=IndexSource(source_type="attachment", attachment_id=9), - retriever_config=RuntimeRetrieverConfig( - name="retriever-a", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://qdrant:6333", - "indexStrategy": {"mode": "per_dataset"}, - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="embedding-a", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "model_id": "text-embedding-3-small", - "base_url": "https://api.openai.com/v1", - }, - ), - splitter_config={"type": "sentence"}, - index_families=["chunk_vector", "summary_vector_index"], - user_name="alice", ) result = await gateway.index_document(spec, db=db) @@ -110,50 +86,19 @@ async def test_remote_gateway_index_document_posts_runtime_request(mocker) -> No assert args[0] == "http://knowledge-runtime/internal/rag/index" assert kwargs["json"] == { "knowledge_base_id": 1, + "user_id": 3, "document_id": 2, - "index_owner_user_id": 3, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - "indexStrategy": {"mode": "per_dataset"}, - }, - }, - "embedding_model_config": { - "model_name": "embedding-a", - "model_namespace": "default", - "resolved_config": { - "protocol": "openai", - "model_id": "text-embedding-3-small", - "base_url": "https://api.openai.com/v1", - }, - }, - "splitter_config": { - "chunk_strategy": "flat", - "format_enhancement": "none", - "flat_config": { - "chunk_size": 1024, - "chunk_overlap": 200, - "separator": "\n\n", - }, - "markdown_enhancement": {"enabled": False}, - "legacy_type": "sentence", - }, - "index_families": ["chunk_vector", "summary_vector_index"], + "source_file": "release-notes.md", + "file_extension": ".md", "content_ref": { "kind": "presigned_url", "url": "https://storage.example.com/release-notes.md", }, - "source_file": "release-notes.md", - "file_extension": ".md", - "user_name": "alice", } @pytest.mark.asyncio -async def test_remote_gateway_query_omits_backend_only_route_fields(mocker) -> None: +async def test_remote_gateway_query_posts_reference_mode_request(mocker) -> None: post_mock = mocker.patch( "httpx.AsyncClient.post", return_value=_build_response( @@ -179,38 +124,10 @@ async def test_remote_gateway_query_omits_backend_only_route_fields(mocker) -> N spec = QueryRuntimeSpec( knowledge_base_ids=[1], query="release checklist", - route_mode="direct_injection", - restricted_mode=True, user_id=8, - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=8, - retriever_config=RuntimeRetrieverConfig( - name="retriever-a", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://qdrant:6333", - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="embedding-a", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - ), - retrieval_config=RuntimeRetrievalConfig( - top_k=20, - score_threshold=0.7, - retrieval_mode="vector", - ), - ) - ], - enabled_index_families=["chunk_vector", "summary_vector_index"], - retrieval_policy="summary_then_chunk_expand", + max_results=5, + document_ids=[10, 11], + metadata_condition={"key": "source", "operator": "==", "value": "kb"}, ) result = await gateway.query(spec) @@ -235,37 +152,15 @@ async def test_remote_gateway_query_omits_backend_only_route_fields(mocker) -> N assert args[0] == "http://knowledge-runtime/internal/rag/query" assert kwargs["json"] == { "knowledge_base_ids": [1], + "user_id": 8, "query": "release checklist", "max_results": 5, - "knowledge_base_configs": [ - { - "knowledge_base_id": 1, - "index_owner_user_id": 8, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, - "embedding_model_config": { - "model_name": "embedding-a", - "model_namespace": "default", - "resolved_config": { - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - }, - "retrieval_config": { - "top_k": 20, - "score_threshold": 0.7, - "retrieval_mode": "vector", - }, - } - ], - "enabled_index_families": ["chunk_vector", "summary_vector_index"], - "retrieval_policy": "summary_then_chunk_expand", + "document_ids": [10, 11], + "metadata_condition": { + "key": "source", + "operator": "==", + "value": "kb", + }, } @@ -294,33 +189,7 @@ async def test_remote_gateway_translates_structured_remote_errors(mocker) -> Non QueryRuntimeSpec( knowledge_base_ids=[1], query="release", - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=8, - retriever_config=RuntimeRetrieverConfig( - name="retriever-a", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://qdrant:6333", - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="embedding-a", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - ), - retrieval_config=RuntimeRetrievalConfig( - top_k=20, - score_threshold=0.7, - retrieval_mode="vector", - ), - ) - ], + user_id=8, ) ) @@ -349,33 +218,7 @@ async def test_remote_gateway_wraps_transport_errors(mocker) -> None: QueryRuntimeSpec( knowledge_base_ids=[1], query="release", - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=8, - retriever_config=RuntimeRetrieverConfig( - name="retriever-a", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://qdrant:6333", - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="embedding-a", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - ), - retrieval_config=RuntimeRetrievalConfig( - top_k=20, - score_threshold=0.7, - retrieval_mode="vector", - ), - ) - ], + user_id=8, ) ) @@ -385,7 +228,7 @@ async def test_remote_gateway_wraps_transport_errors(mocker) -> None: @pytest.mark.asyncio -async def test_remote_gateway_delete_posts_resolved_retriever_config(mocker) -> None: +async def test_remote_gateway_delete_posts_reference_mode_request(mocker) -> None: post_mock = mocker.patch( "httpx.AsyncClient.post", return_value=_build_response( @@ -404,11 +247,7 @@ async def test_remote_gateway_delete_posts_resolved_retriever_config(mocker) -> retriever_config=RuntimeRetrieverConfig( name="retriever-a", namespace="default", - storage_config={ - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_user"}, - }, + storage_config={"type": "elasticsearch"}, ), ) @@ -419,23 +258,15 @@ async def test_remote_gateway_delete_posts_resolved_retriever_config(mocker) -> assert args[0] == "http://knowledge-runtime/internal/rag/delete-document-index" assert kwargs["json"] == { "knowledge_base_id": 1, + "user_id": 7, "document_ref": "9", - "index_owner_user_id": 7, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_user"}, - }, - }, - "enabled_index_families": ["chunk_vector"], } @pytest.mark.asyncio -async def test_remote_gateway_purge_index_posts_runtime_request(mocker) -> None: +async def test_remote_gateway_purge_index_posts_reference_mode_request( + mocker, +) -> None: post_mock = mocker.patch( "httpx.AsyncClient.post", return_value=_build_response( @@ -453,11 +284,7 @@ async def test_remote_gateway_purge_index_posts_runtime_request(mocker) -> None: retriever_config=RuntimeRetrieverConfig( name="retriever-a", namespace="default", - storage_config={ - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_user"}, - }, + storage_config={"type": "elasticsearch"}, ), ) @@ -468,21 +295,12 @@ async def test_remote_gateway_purge_index_posts_runtime_request(mocker) -> None: assert args[0] == "http://knowledge-runtime/internal/rag/purge-knowledge-index" assert kwargs["json"] == { "knowledge_base_id": 1, - "index_owner_user_id": 7, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_user"}, - }, - }, + "user_id": 7, } @pytest.mark.asyncio -async def test_remote_gateway_drop_index_posts_runtime_request(mocker) -> None: +async def test_remote_gateway_drop_index_posts_reference_mode_request(mocker) -> None: post_mock = mocker.patch( "httpx.AsyncClient.post", return_value=_build_response( @@ -500,11 +318,7 @@ async def test_remote_gateway_drop_index_posts_runtime_request(mocker) -> None: retriever_config=RuntimeRetrieverConfig( name="retriever-a", namespace="default", - storage_config={ - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_dataset"}, - }, + storage_config={"type": "elasticsearch"}, ), ) @@ -515,61 +329,12 @@ async def test_remote_gateway_drop_index_posts_runtime_request(mocker) -> None: assert args[0] == "http://knowledge-runtime/internal/rag/drop-knowledge-index" assert kwargs["json"] == { "knowledge_base_id": 1, - "index_owner_user_id": 7, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_dataset"}, - }, - }, + "user_id": 7, } @pytest.mark.asyncio -async def test_remote_gateway_test_connection_posts_resolved_retriever_config( - mocker, -) -> None: - post_mock = mocker.patch( - "httpx.AsyncClient.post", - return_value=_build_response( - url="http://knowledge-runtime/internal/rag/test-connection", - status_code=200, - json_body={"success": True, "message": "Connection successful"}, - ), - ) - gateway = RemoteRagGateway( - base_url="http://knowledge-runtime", - ) - spec = ConnectionTestRuntimeSpec( - retriever_config=RuntimeRetrieverConfig( - name="retriever-a", - namespace="default", - storage_config={"type": "elasticsearch", "url": "http://es:9200"}, - ) - ) - - result = await gateway.test_connection(spec, db=MagicMock()) - - assert result == {"success": True, "message": "Connection successful"} - args, kwargs = post_mock.await_args - assert args[0] == "http://knowledge-runtime/internal/rag/test-connection" - assert kwargs["json"] == { - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "elasticsearch", - "url": "http://es:9200", - }, - } - } - - -@pytest.mark.asyncio -async def test_remote_gateway_list_chunks_posts_runtime_request(mocker) -> None: +async def test_remote_gateway_list_chunks_posts_reference_mode_request(mocker) -> None: post_mock = mocker.patch( "httpx.AsyncClient.post", return_value=_build_response( @@ -598,10 +363,7 @@ async def test_remote_gateway_list_chunks_posts_runtime_request(mocker) -> None: retriever_config=RuntimeRetrieverConfig( name="retriever-a", namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://qdrant:6333", - }, + storage_config={"type": "qdrant"}, ), max_chunks=1000, query="list_index_chunks", @@ -631,15 +393,7 @@ async def test_remote_gateway_list_chunks_posts_runtime_request(mocker) -> None: assert args[0] == "http://knowledge-runtime/internal/rag/all-chunks" assert kwargs["json"] == { "knowledge_base_id": 1, - "index_owner_user_id": 8, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, + "user_id": 8, "max_chunks": 1000, "query": "list_index_chunks", "metadata_condition": { @@ -702,23 +456,7 @@ async def test_gateway_adds_auth_header_when_token_configured(mocker) -> None: spec = QueryRuntimeSpec( knowledge_base_ids=[1], query="test", - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=1, - retriever_config=RuntimeRetrieverConfig( - name="test", - namespace="default", - storage_config={"type": "elasticsearch", "url": "http://es:9200"}, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="test", - model_namespace="default", - resolved_config={"protocol": "openai", "model_id": "test-model"}, - ), - retrieval_config=RuntimeRetrievalConfig(), - ) - ], + user_id=1, ) await gateway.query(spec) @@ -749,23 +487,7 @@ async def test_gateway_no_auth_header_when_token_empty(mocker, monkeypatch) -> N spec = QueryRuntimeSpec( knowledge_base_ids=[1], query="test", - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=1, - retriever_config=RuntimeRetrieverConfig( - name="test", - namespace="default", - storage_config={"type": "elasticsearch", "url": "http://es:9200"}, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="test", - model_namespace="default", - resolved_config={"protocol": "openai", "model_id": "test-model"}, - ), - retrieval_config=RuntimeRetrievalConfig(), - ) - ], + user_id=1, ) await gateway.query(spec) @@ -797,23 +519,7 @@ async def test_gateway_uses_settings_token_when_not_provided( spec = QueryRuntimeSpec( knowledge_base_ids=[1], query="test", - knowledge_base_configs=[ - QueryKnowledgeBaseRuntimeConfig( - knowledge_base_id=1, - index_owner_user_id=1, - retriever_config=RuntimeRetrieverConfig( - name="test", - namespace="default", - storage_config={"type": "elasticsearch", "url": "http://es:9200"}, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="test", - model_namespace="default", - resolved_config={"protocol": "openai", "model_id": "test-model"}, - ), - retrieval_config=RuntimeRetrievalConfig(), - ) - ], + user_id=1, ) await gateway.query(spec) diff --git a/backend/tests/services/rag/test_retrieval_service.py b/backend/tests/services/rag/test_retrieval_service.py index 465f5a964..77e228f99 100644 --- a/backend/tests/services/rag/test_retrieval_service.py +++ b/backend/tests/services/rag/test_retrieval_service.py @@ -44,10 +44,10 @@ async def test_retrieve_for_chat_shell_no_longer_persists_subtask_context(): mock_update_context = MagicMock() with patch.multiple( - context_service, - get_knowledge_base_context_map_by_subtask=mock_get_context_map, - create_knowledge_base_context_with_result=mock_create_context, - update_knowledge_base_retrieval_result=mock_update_context, + context_service, + get_knowledge_base_context_map_by_subtask=mock_get_context_map, + create_knowledge_base_context_with_result=mock_create_context, + update_knowledge_base_retrieval_result=mock_update_context, ): result = await service.retrieve_with_routing( query="test", @@ -98,7 +98,7 @@ async def test_get_all_chunks_without_user_auth_check(self): with ( patch( - "app.services.rag.retrieval_service.RagRuntimeResolver.build_internal_list_chunks_runtime_spec", + "app.services.rag.retrieval_service.RagRuntimeResolver.build_public_list_chunks_runtime_spec", return_value=spec, ) as mock_build_spec, patch.object( @@ -111,6 +111,7 @@ async def test_get_all_chunks_without_user_auth_check(self): result = await RetrievalService().get_all_chunks_from_knowledge_base( knowledge_base_id=123, db=db_session, + user_id=42, max_chunks=50, query="debug query", ) @@ -127,6 +128,8 @@ async def test_get_all_chunks_without_user_auth_check(self): mock_build_spec.assert_called_once_with( db=db_session, knowledge_base_id=123, + user_id=42, + user_name=None, max_chunks=50, query="debug query", metadata_condition=None, @@ -182,9 +185,9 @@ async def test_auto_route_returns_direct_injection_records(self): db = MagicMock() with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ) as mock_estimate: service = RetrievalService() service.get_original_documents_from_knowledge_base = AsyncMock( @@ -220,7 +223,7 @@ async def test_auto_route_returns_direct_injection_records(self): @pytest.mark.asyncio async def test_auto_route_falls_back_to_rag_when_runtime_budget_is_insufficient( - self, + self, ): """Backend should own the final fit check when runtime budget is provided.""" from app.services.rag.retrieval_service import RetrievalService @@ -228,9 +231,9 @@ async def test_auto_route_falls_back_to_rag_when_runtime_budget_is_insufficient( db = MagicMock() with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ): service = RetrievalService() service.get_original_documents_from_knowledge_base = AsyncMock( @@ -343,9 +346,9 @@ async def test_auto_route_estimates_only_filtered_documents(self): ) with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ) as mock_estimate: result = await service.retrieve_with_routing( query="test", @@ -366,7 +369,7 @@ async def test_auto_route_estimates_only_filtered_documents(self): assert result["records"][0]["knowledge_base_id"] == 123 def test_decide_route_mode_for_chat_shell_returns_rag_retrieval_without_budget( - self, + self, ): from app.services.rag.retrieval_service import RetrievalService @@ -384,7 +387,7 @@ def test_decide_route_mode_for_chat_shell_returns_rag_retrieval_without_budget( assert result == "rag_retrieval" def test_decide_route_mode_for_chat_shell_returns_direct_injection_when_auto_fits( - self, + self, ): from app.services.rag.retrieval_service import RetrievalService @@ -392,9 +395,9 @@ def test_decide_route_mode_for_chat_shell_returns_direct_injection_when_auto_fit db = MagicMock() with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ) as mock_estimate: result = service.decide_route_mode_for_chat_shell( query="test", @@ -413,7 +416,7 @@ def test_decide_route_mode_for_chat_shell_returns_direct_injection_when_auto_fit assert result == "direct_injection" def test_decide_route_mode_for_chat_shell_skips_direct_injection_when_auto_disabled( - self, monkeypatch + self, monkeypatch ): from app.core.config import settings from app.services.rag.retrieval_service import RetrievalService @@ -429,9 +432,9 @@ def test_decide_route_mode_for_chat_shell_skips_direct_injection_when_auto_disab db = MagicMock() with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ) as mock_estimate: result = service.decide_route_mode_for_chat_shell( query="test", @@ -451,9 +454,9 @@ def test_decide_route_mode_for_chat_shell_uses_live_runtime_budget(self): db = MagicMock() with patch.object( - RetrievalService, - "_estimate_total_tokens_for_knowledge_bases", - return_value=100, + RetrievalService, + "_estimate_total_tokens_for_knowledge_bases", + return_value=100, ): result = service.decide_route_mode_for_chat_shell( query="test", @@ -469,7 +472,7 @@ def test_decide_route_mode_for_chat_shell_uses_live_runtime_budget(self): assert result == "rag_retrieval" def test_decide_route_mode_for_chat_shell_forces_rag_when_metadata_filter_exists( - self, + self, ): from app.services.rag.retrieval_service import RetrievalService @@ -696,7 +699,7 @@ async def test_force_rag_route_sorts_and_limits_results_globally(self): @pytest.mark.asyncio async def test_force_rag_route_uses_knowledge_engine_query_executor_when_runtime_configs_are_available( - self, + self, ): """Resolved runtime configs should drive the engine query seam in local mode.""" from app.services.rag.retrieval_service import RetrievalService diff --git a/docker-compose.yml b/docker-compose.yml index 239cb8b07..e1f007644 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -284,6 +284,8 @@ services: - LOG_LEVEL=${LOG_LEVEL:-INFO} # Internal service token (uncomment to enable, must match Backend's INTERNAL_SERVICE_TOKEN) # - INTERNAL_SERVICE_TOKEN=${INTERNAL_SERVICE_TOKEN:-your-secure-token-here} + # Database connection for config resolution (reference-mode) + - DATABASE_URL=mysql+pymysql://${MYSQL_USER:-task_user}:${MYSQL_PASSWORD:-task_password}@mysql:3306/${MYSQL_DATABASE:-task_manager} # OpenTelemetry Configuration (uncomment to enable) # - OTEL_ENABLED=true # - OTEL_SERVICE_NAME=wegent-knowledge-runtime diff --git a/knowledge_runtime/knowledge_runtime/api/endpoints/admin.py b/knowledge_runtime/knowledge_runtime/api/endpoints/admin.py index ab4d24806..703b429ae 100644 --- a/knowledge_runtime/knowledge_runtime/api/endpoints/admin.py +++ b/knowledge_runtime/knowledge_runtime/api/endpoints/admin.py @@ -8,16 +8,17 @@ from typing import Any -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session from knowledge_runtime.services.admin_executor import AdminExecutor +from shared.db.sync_session import get_db from shared.models import ( RemoteDeleteDocumentIndexRequest, RemoteDropKnowledgeIndexRequest, RemoteListChunksRequest, RemoteListChunksResponse, RemotePurgeKnowledgeIndexRequest, - RemoteTestConnectionRequest, ) router = APIRouter() @@ -26,74 +27,38 @@ @router.post("/delete-document-index") async def delete_document_index( request: RemoteDeleteDocumentIndexRequest, + db: Session = Depends(get_db), ) -> dict[str, Any]: - """Delete a document's index from a knowledge base. - - Args: - request: The delete request. - - Returns: - Deletion result. - """ - executor = AdminExecutor() + """Delete a document's index from a knowledge base.""" + executor = AdminExecutor(db=db) return await executor.delete_document_index(request) @router.post("/purge-knowledge-index") async def purge_knowledge_index( request: RemotePurgeKnowledgeIndexRequest, + db: Session = Depends(get_db), ) -> dict[str, Any]: - """Delete all chunks for a knowledge base. - - Args: - request: The purge request. - - Returns: - Purge result. - """ - executor = AdminExecutor() + """Delete all chunks for a knowledge base.""" + executor = AdminExecutor(db=db) return await executor.purge_knowledge_index(request) @router.post("/drop-knowledge-index") async def drop_knowledge_index( request: RemoteDropKnowledgeIndexRequest, + db: Session = Depends(get_db), ) -> dict[str, Any]: - """Physically drop the index/collection for a knowledge base. - - Args: - request: The drop request. - - Returns: - Drop result. - """ - executor = AdminExecutor() + """Physically drop the index/collection for a knowledge base.""" + executor = AdminExecutor(db=db) return await executor.drop_knowledge_index(request) @router.post("/all-chunks") -async def list_chunks(request: RemoteListChunksRequest) -> RemoteListChunksResponse: - """List all chunks in a knowledge base. - - Args: - request: The list request. - - Returns: - List of chunks. - """ - executor = AdminExecutor() +async def list_chunks( + request: RemoteListChunksRequest, + db: Session = Depends(get_db), +) -> RemoteListChunksResponse: + """List all chunks in a knowledge base.""" + executor = AdminExecutor(db=db) return await executor.list_chunks(request) - - -@router.post("/test-connection") -async def test_connection(request: RemoteTestConnectionRequest) -> dict[str, Any]: - """Test connection to a storage backend. - - Args: - request: The test request. - - Returns: - Connection test result. - """ - executor = AdminExecutor() - return await executor.test_connection(request) diff --git a/knowledge_runtime/knowledge_runtime/api/endpoints/index.py b/knowledge_runtime/knowledge_runtime/api/endpoints/index.py index 101c082a0..e84797ad0 100644 --- a/knowledge_runtime/knowledge_runtime/api/endpoints/index.py +++ b/knowledge_runtime/knowledge_runtime/api/endpoints/index.py @@ -8,28 +8,29 @@ from typing import Any -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session from knowledge_runtime.services.index_executor import IndexExecutor +from shared.db.sync_session import get_db from shared.models import RemoteIndexRequest router = APIRouter() @router.post("/index") -async def index_document(request: RemoteIndexRequest) -> dict[str, Any]: +async def index_document( + request: RemoteIndexRequest, + db: Session = Depends(get_db), +) -> dict[str, Any]: """Index a document for RAG retrieval. - This endpoint: - 1. Fetches content from the provided ContentRef - 2. Creates storage backend and embedding model from configs - 3. Indexes the document chunks into the vector store - Args: - request: The index request containing content reference and configs. + request: The index request containing knowledge_base_id and content_ref. + db: Database session for config resolution. Returns: Indexing result with chunk_count, doc_ref, etc. """ - executor = IndexExecutor() + executor = IndexExecutor(db=db) return await executor.execute(request) diff --git a/knowledge_runtime/knowledge_runtime/api/endpoints/query.py b/knowledge_runtime/knowledge_runtime/api/endpoints/query.py index 3d0cf95e3..61cffe0d0 100644 --- a/knowledge_runtime/knowledge_runtime/api/endpoints/query.py +++ b/knowledge_runtime/knowledge_runtime/api/endpoints/query.py @@ -6,28 +6,29 @@ from __future__ import annotations -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session from knowledge_runtime.services.query_executor import QueryExecutor +from shared.db.sync_session import get_db from shared.models import RemoteQueryRequest, RemoteQueryResponse router = APIRouter() @router.post("/query") -async def query_documents(request: RemoteQueryRequest) -> RemoteQueryResponse: +async def query_documents( + request: RemoteQueryRequest, + db: Session = Depends(get_db), +) -> RemoteQueryResponse: """Query documents for RAG retrieval. - This endpoint: - 1. Creates storage backends and embedding models for each KB config - 2. Executes the query against each knowledge base - 3. Aggregates and ranks results by score - Args: - request: The query request containing query text and KB configs. + request: The query request containing query text and knowledge_base_ids. + db: Database session for config resolution. Returns: Query response with ranked records. """ - executor = QueryExecutor() + executor = QueryExecutor(db=db) return await executor.execute(request) diff --git a/knowledge_runtime/knowledge_runtime/config.py b/knowledge_runtime/knowledge_runtime/config.py index 80acf3450..879f073d5 100644 --- a/knowledge_runtime/knowledge_runtime/config.py +++ b/knowledge_runtime/knowledge_runtime/config.py @@ -51,6 +51,12 @@ class Settings(BaseSettings): # Generate using: openssl rand -hex 32 internal_service_token: str = "" + # Database connection for config resolution + database_url: str = Field( + default="", + validation_alias=AliasChoices("KNOWLEDGE_RUNTIME_DATABASE_URL", "DATABASE_URL"), + ) + # Global settings instance _settings: Settings | None = None diff --git a/knowledge_runtime/knowledge_runtime/main.py b/knowledge_runtime/knowledge_runtime/main.py index a35ea3709..78603d3e0 100644 --- a/knowledge_runtime/knowledge_runtime/main.py +++ b/knowledge_runtime/knowledge_runtime/main.py @@ -33,6 +33,15 @@ async def lifespan(app: FastAPI): log_level=settings.log_level, ) + # Initialize database connection for config resolution + if settings.database_url: + from shared.db.sync_session import init_db + + init_db(settings.database_url) + logger.info("Database initialized for config resolution") + else: + logger.warning("DATABASE_URL not configured - config resolution will not work") + logger.info( f"knowledge_runtime starting on {settings.host}:{settings.port}", ) diff --git a/knowledge_runtime/knowledge_runtime/models/__init__.py b/knowledge_runtime/knowledge_runtime/models/__init__.py new file mode 100644 index 000000000..defb94adf --- /dev/null +++ b/knowledge_runtime/knowledge_runtime/models/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/knowledge_runtime/knowledge_runtime/models/knowledge_document.py b/knowledge_runtime/knowledge_runtime/models/knowledge_document.py new file mode 100644 index 000000000..0099a28ee --- /dev/null +++ b/knowledge_runtime/knowledge_runtime/models/knowledge_document.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Minimal KnowledgeDocument model for knowledge_runtime. + +Only the fields needed by KR for config resolution are defined here. +The full model lives in the Backend module with additional fields +and enums not required by KR. +""" + +from __future__ import annotations + +from sqlalchemy import JSON, Column, Integer + +from shared.models.db.base import Base + + +class KnowledgeDocument(Base): + """Minimal model for knowledge_documents table (KR only needs 3 fields).""" + + __tablename__ = "knowledge_documents" + + id = Column(Integer, primary_key=True, index=True) + attachment_id = Column(Integer, nullable=False, default=0) + splitter_config = Column(JSON, nullable=False, default={}) diff --git a/knowledge_runtime/knowledge_runtime/services/admin_executor.py b/knowledge_runtime/knowledge_runtime/services/admin_executor.py index eb6309371..207af66d0 100644 --- a/knowledge_runtime/knowledge_runtime/services/admin_executor.py +++ b/knowledge_runtime/knowledge_runtime/services/admin_executor.py @@ -10,8 +10,14 @@ import logging from typing import Any +from sqlalchemy.orm import Session + from knowledge_engine.services.document_service import DocumentService from knowledge_engine.storage.factory import create_storage_backend_from_runtime_config +from knowledge_runtime.services.config_resolver import ( + AdminResolvedConfig, + ConfigResolver, +) from shared.models import ( RemoteDeleteDocumentIndexRequest, RemoteDropKnowledgeIndexRequest, @@ -19,7 +25,6 @@ RemoteListChunksRequest, RemoteListChunksResponse, RemotePurgeKnowledgeIndexRequest, - RemoteTestConnectionRequest, ) from shared.telemetry.decorators import trace_async @@ -37,6 +42,10 @@ class AdminExecutor: - test_connection: Test storage backend connection """ + def __init__(self, db: Session) -> None: + self._db = db + self._config_resolver = ConfigResolver() + @trace_async( span_name="delete_document_index", tracer_name="knowledge_runtime.services.admin", @@ -45,30 +54,28 @@ async def delete_document_index( self, request: RemoteDeleteDocumentIndexRequest, ) -> dict[str, Any]: - """Delete a document's index from a knowledge base. - - Args: - request: The delete request. + """Delete a document's index from a knowledge base.""" + config = self._config_resolver.resolve_admin_config( + self._db, + knowledge_base_id=request.knowledge_base_id, + ) - Returns: - Deletion result. - """ storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config + config.retriever_config ) - knowledge_id = str(request.knowledge_base_id) logger.info( - f"Deleting document index: knowledge_base_id={request.knowledge_base_id}, " - f"doc_ref={request.document_ref}" + "Deleting document index: knowledge_base_id=%d, doc_ref=%s", + request.knowledge_base_id, + request.document_ref, ) result = await asyncio.to_thread( storage_backend.delete_document, knowledge_id=knowledge_id, doc_ref=request.document_ref, - user_id=request.index_owner_user_id, + user_id=config.index_owner_user_id, ) return result @@ -81,28 +88,26 @@ async def purge_knowledge_index( self, request: RemotePurgeKnowledgeIndexRequest, ) -> dict[str, Any]: - """Delete all chunks for a knowledge base. - - Args: - request: The purge request. + """Delete all chunks for a knowledge base.""" + config = self._config_resolver.resolve_admin_config( + self._db, + knowledge_base_id=request.knowledge_base_id, + ) - Returns: - Purge result. - """ storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config + config.retriever_config ) - knowledge_id = str(request.knowledge_base_id) logger.info( - f"Purging knowledge base index: knowledge_base_id={request.knowledge_base_id}" + "Purging knowledge base index: knowledge_base_id=%d", + request.knowledge_base_id, ) result = await asyncio.to_thread( storage_backend.delete_knowledge, knowledge_id=knowledge_id, - user_id=request.index_owner_user_id, + user_id=config.index_owner_user_id, ) return result @@ -115,28 +120,26 @@ async def drop_knowledge_index( self, request: RemoteDropKnowledgeIndexRequest, ) -> dict[str, Any]: - """Physically drop the index/collection for a knowledge base. - - Args: - request: The drop request. + """Physically drop the index/collection for a knowledge base.""" + config = self._config_resolver.resolve_admin_config( + self._db, + knowledge_base_id=request.knowledge_base_id, + ) - Returns: - Drop result. - """ storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config + config.retriever_config ) - knowledge_id = str(request.knowledge_base_id) logger.info( - f"Dropping knowledge base index: knowledge_base_id={request.knowledge_base_id}" + "Dropping knowledge base index: knowledge_base_id=%d", + request.knowledge_base_id, ) result = await asyncio.to_thread( storage_backend.drop_knowledge_index, knowledge_id=knowledge_id, - user_id=request.index_owner_user_id, + user_id=config.index_owner_user_id, ) return result @@ -149,18 +152,15 @@ async def list_chunks( self, request: RemoteListChunksRequest, ) -> RemoteListChunksResponse: - """List all chunks in a knowledge base. - - Args: - request: The list request. + """List all chunks in a knowledge base.""" + config = self._config_resolver.resolve_admin_config( + self._db, + knowledge_base_id=request.knowledge_base_id, + ) - Returns: - List of chunks. - """ storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config + config.retriever_config ) - knowledge_id = str(request.knowledge_base_id) chunks = await asyncio.to_thread( @@ -168,7 +168,7 @@ async def list_chunks( knowledge_id=knowledge_id, max_chunks=request.max_chunks, metadata_condition=request.metadata_condition, - user_id=request.index_owner_user_id, + user_id=config.index_owner_user_id, ) records = [ @@ -183,46 +183,13 @@ async def list_chunks( ] logger.info( - f"Listed chunks: knowledge_base_id={request.knowledge_base_id}, " - f"count={len(records)}, max_chunks={request.max_chunks}" + "Listed chunks: knowledge_base_id=%d, count=%d, max_chunks=%d", + request.knowledge_base_id, + len(records), + request.max_chunks, ) return RemoteListChunksResponse( chunks=records, total=len(records), ) - - @trace_async( - span_name="test_connection", - tracer_name="knowledge_runtime.services.admin", - ) - async def test_connection( - self, - request: RemoteTestConnectionRequest, - ) -> dict[str, Any]: - """Test connection to a storage backend. - - Args: - request: The test request. - - Returns: - Connection test result. - """ - storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config - ) - - logger.info("Testing storage backend connection") - - try: - success = await asyncio.to_thread(storage_backend.test_connection) - return { - "success": success, - "message": "Connection successful" if success else "Connection failed", - } - except Exception as e: - logger.error(f"Connection test failed: {e}") - return { - "success": False, - "message": str(e), - } diff --git a/knowledge_runtime/knowledge_runtime/services/config_resolver.py b/knowledge_runtime/knowledge_runtime/services/config_resolver.py new file mode 100644 index 000000000..42e62a344 --- /dev/null +++ b/knowledge_runtime/knowledge_runtime/services/config_resolver.py @@ -0,0 +1,441 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Config resolver for knowledge_runtime. + +Resolves runtime configurations (retriever, embedding model, splitter) +from the database using knowledge_base_id and user_id as references. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy.orm import Session + +from knowledge_runtime.models.knowledge_document import KnowledgeDocument +from shared.models import ( + RuntimeEmbeddingModelConfig, + RuntimeRetrievalConfig, + RuntimeRetrieverConfig, +) +from shared.models.db import Kind, User +from shared.utils.crypto import decrypt_api_key +from shared.utils.placeholder import process_custom_headers_placeholders + +logger = logging.getLogger(__name__) + + +@dataclass +class IndexConfig: + """Resolved configuration for document indexing.""" + + index_owner_user_id: int + retriever_config: RuntimeRetrieverConfig + embedding_model_config: RuntimeEmbeddingModelConfig + splitter_config: dict[str, Any] = field(default_factory=dict) + user_name: str | None = None + + +@dataclass +class QueryConfig: + """Resolved configuration for querying a single knowledge base.""" + + knowledge_base_id: int + index_owner_user_id: int + retriever_config: RuntimeRetrieverConfig + embedding_model_config: RuntimeEmbeddingModelConfig + retrieval_config: RuntimeRetrievalConfig + user_name: str | None = None + + +@dataclass +class AdminResolvedConfig: + """Resolved configuration for admin operations (delete/purge/drop/list).""" + + index_owner_user_id: int + retriever_config: RuntimeRetrieverConfig + + +class ConfigResolutionError(ValueError): + """Raised when config resolution fails with a specific error code.""" + + def __init__(self, code: str, message: str) -> None: + self.code = code + super().__init__(message) + + +class ConfigResolver: + """Resolve runtime configs from database by knowledge_base_id + user_id.""" + + def resolve_index_config( + self, + db: Session, + *, + knowledge_base_id: int, + user_id: int, + document_id: int | None = None, + ) -> IndexConfig: + """Resolve all configs needed for document indexing.""" + kb = self._get_knowledge_base(db, knowledge_base_id) + index_owner_user_id = kb.user_id + user_name = self._get_user_name(db, user_id) + + retrieval_config = self._parse_kb_retrieval_config(kb) + + retriever_config = self._build_resolved_retriever_config( + db=db, + user_id=index_owner_user_id, + name=retrieval_config["retriever_name"], + namespace=retrieval_config["retriever_namespace"], + ) + embedding_model_config = self._build_resolved_embedding_model_config( + db=db, + user_id=index_owner_user_id, + model_name=retrieval_config["embedding_model_name"], + model_namespace=retrieval_config["embedding_model_namespace"], + user_name=user_name, + ) + + splitter_config: dict[str, Any] = {} + if document_id is not None: + splitter_config = self._get_splitter_config(db, document_id) + + return IndexConfig( + index_owner_user_id=index_owner_user_id, + retriever_config=retriever_config, + embedding_model_config=embedding_model_config, + splitter_config=splitter_config, + user_name=user_name, + ) + + def resolve_query_config( + self, + db: Session, + *, + knowledge_base_id: int, + user_id: int, + ) -> QueryConfig: + """Resolve configs needed for querying a single knowledge base.""" + kb = self._get_knowledge_base(db, knowledge_base_id) + index_owner_user_id = kb.user_id + user_name = self._get_user_name(db, user_id) + + retrieval_config = self._parse_kb_retrieval_config(kb) + + retriever_config = self._build_resolved_retriever_config( + db=db, + user_id=index_owner_user_id, + name=retrieval_config["retriever_name"], + namespace=retrieval_config["retriever_namespace"], + ) + embedding_model_config = self._build_resolved_embedding_model_config( + db=db, + user_id=index_owner_user_id, + model_name=retrieval_config["embedding_model_name"], + model_namespace=retrieval_config["embedding_model_namespace"], + user_name=user_name, + ) + + rc = retrieval_config + retrieval_mode = rc.get("retrieval_mode", "vector") + hybrid_weights = rc.get("hybrid_weights") or {} + runtime_retrieval_config = RuntimeRetrievalConfig( + top_k=rc.get("top_k", 20), + score_threshold=rc.get("score_threshold", 0.7), + retrieval_mode=retrieval_mode, + vector_weight=( + hybrid_weights.get("vector_weight") + if retrieval_mode == "hybrid" + else None + ), + keyword_weight=( + hybrid_weights.get("keyword_weight") + if retrieval_mode == "hybrid" + else None + ), + ) + + return QueryConfig( + knowledge_base_id=knowledge_base_id, + index_owner_user_id=index_owner_user_id, + retriever_config=retriever_config, + embedding_model_config=embedding_model_config, + retrieval_config=runtime_retrieval_config, + user_name=user_name, + ) + + def resolve_admin_config( + self, + db: Session, + *, + knowledge_base_id: int, + ) -> AdminResolvedConfig: + """Resolve config for admin operations (delete/purge/drop/list).""" + kb = self._get_knowledge_base(db, knowledge_base_id) + retrieval_config = self._parse_kb_retrieval_config(kb) + + retriever_config = self._build_resolved_retriever_config( + db=db, + user_id=kb.user_id, + name=retrieval_config["retriever_name"], + namespace=retrieval_config["retriever_namespace"], + ) + + return AdminResolvedConfig( + index_owner_user_id=kb.user_id, + retriever_config=retriever_config, + ) + + # --- Private methods --- + + def _get_knowledge_base(self, db: Session, knowledge_base_id: int) -> Kind: + """Get KB record or raise ConfigResolutionError.""" + kb = ( + db.query(Kind) + .filter( + Kind.id == knowledge_base_id, + Kind.kind == "KnowledgeBase", + Kind.is_active.is_(True), + ) + .first() + ) + if kb is None: + raise ConfigResolutionError( + "config_not_found", + f"Knowledge base {knowledge_base_id} not found", + ) + return kb + + def _parse_kb_retrieval_config(self, kb: Kind) -> dict[str, Any]: + """Parse KB's retrievalConfig from its JSON spec.""" + retrieval_config = (kb.json or {}).get("spec", {}).get("retrievalConfig") or {} + retriever_name = retrieval_config.get("retriever_name") + retriever_namespace = retrieval_config.get("retriever_namespace", "default") + embedding_config = retrieval_config.get("embedding_config") or {} + embedding_model_name = embedding_config.get("model_name") + embedding_model_namespace = embedding_config.get("model_namespace", "default") + + if not retriever_name: + raise ConfigResolutionError( + "config_incomplete", + f"Knowledge base {kb.id} has incomplete retrieval config (missing retriever_name)", + ) + if not embedding_model_name: + raise ConfigResolutionError( + "config_incomplete", + f"Knowledge base {kb.id} has incomplete embedding config", + ) + + return { + "retriever_name": retriever_name, + "retriever_namespace": retriever_namespace, + "embedding_model_name": embedding_model_name, + "embedding_model_namespace": embedding_model_namespace, + "top_k": retrieval_config.get("top_k", 20), + "score_threshold": retrieval_config.get("score_threshold", 0.7), + "retrieval_mode": retrieval_config.get("retrieval_mode", "vector"), + "hybrid_weights": retrieval_config.get("hybrid_weights"), + } + + def _build_resolved_retriever_config( + self, + *, + db: Session, + user_id: int, + name: str, + namespace: str, + ) -> RuntimeRetrieverConfig: + """Build resolved retriever config with decrypted credentials.""" + retriever = self._get_retriever_kind( + db, user_id=user_id, name=name, namespace=namespace + ) + if retriever is None: + raise ConfigResolutionError( + "config_not_found", + f"Retriever {name} (namespace: {namespace}) not found", + ) + + spec = retriever.json or {} + storage_config = spec.get("spec", {}).get("storageConfig", {}) + + return RuntimeRetrieverConfig( + name=name, + namespace=namespace, + storage_config={ + "type": storage_config.get("type"), + "url": storage_config.get("url"), + "username": storage_config.get("username"), + "password": self._decrypt_optional_value( + storage_config.get("password") + ), + "apiKey": self._decrypt_optional_value(storage_config.get("apiKey")), + "indexStrategy": storage_config.get( + "indexStrategy", {"mode": "per_dataset"} + ), + "ext": storage_config.get("ext", {}), + }, + ) + + def _build_resolved_embedding_model_config( + self, + *, + db: Session, + user_id: int, + model_name: str, + model_namespace: str, + user_name: str | None, + ) -> RuntimeEmbeddingModelConfig: + """Build resolved embedding model config with decrypted API key.""" + model_kind = self._get_model_kind( + db=db, + user_id=user_id, + model_name=model_name, + model_namespace=model_namespace, + ) + if model_kind is None: + raise ConfigResolutionError( + "config_not_found", + f"Embedding model '{model_name}' not found in namespace '{model_namespace}'", + ) + + spec = (model_kind.json or {}).get("spec", {}) + model_config = spec.get("modelConfig", {}) + env = model_config.get("env", {}) + protocol = spec.get("protocol") or env.get("model") + custom_headers = env.get("custom_headers", {}) + if custom_headers and isinstance(custom_headers, dict): + custom_headers = process_custom_headers_placeholders( + custom_headers, user_name + ) + + embedding_config = spec.get("embeddingConfig", {}) + dimensions = embedding_config.get("dimensions") if embedding_config else None + + return RuntimeEmbeddingModelConfig( + model_name=model_name, + model_namespace=model_namespace, + resolved_config={ + "protocol": protocol, + "api_key": self._decrypt_optional_value(env.get("api_key")), + "base_url": env.get("base_url"), + "model_id": env.get("model_id"), + "custom_headers": ( + custom_headers if isinstance(custom_headers, dict) else {} + ), + "dimensions": dimensions, + }, + ) + + def _get_retriever_kind( + self, + db: Session, + *, + user_id: int, + name: str, + namespace: str, + ) -> Kind | None: + """Get Retriever Kind with priority: user's own > public (user_id=0).""" + if namespace == "default": + return ( + db.query(Kind) + .filter( + Kind.kind == "Retriever", + Kind.name == name, + Kind.namespace == namespace, + Kind.is_active.is_(True), + ) + .filter((Kind.user_id == user_id) | (Kind.user_id == 0)) + .order_by(Kind.user_id.desc()) + .first() + ) + # Group retriever: no user_id filter, fallback to public + kind = ( + db.query(Kind) + .filter( + Kind.kind == "Retriever", + Kind.name == name, + Kind.namespace == namespace, + Kind.is_active.is_(True), + ) + .first() + ) + if kind is not None: + return kind + # Fallback to public retriever + return ( + db.query(Kind) + .filter( + Kind.user_id == 0, + Kind.name == name, + Kind.kind == "Retriever", + Kind.namespace == "default", + Kind.is_active.is_(True), + ) + .first() + ) + + def _get_model_kind( + self, + *, + db: Session, + user_id: int, + model_name: str, + model_namespace: str, + ) -> Kind | None: + """Get Model Kind with priority: user's own > public (user_id=0).""" + if model_namespace == "default": + return ( + db.query(Kind) + .filter( + Kind.kind == "Model", + Kind.name == model_name, + Kind.namespace == model_namespace, + Kind.is_active.is_(True), + ) + .filter((Kind.user_id == user_id) | (Kind.user_id == 0)) + .order_by(Kind.user_id.desc()) + .first() + ) + return ( + db.query(Kind) + .filter( + Kind.kind == "Model", + Kind.name == model_name, + Kind.namespace == model_namespace, + Kind.is_active.is_(True), + ) + .first() + ) + + def _get_splitter_config(self, db: Session, document_id: int) -> dict[str, Any]: + """Get splitter_config from knowledge_documents table.""" + doc = ( + db.query(KnowledgeDocument) + .filter(KnowledgeDocument.id == document_id) + .first() + ) + if doc is None: + raise ConfigResolutionError( + "config_not_found", + f"Document {document_id} not found", + ) + return doc.splitter_config or {} + + def _get_user_name(self, db: Session, user_id: int) -> str | None: + """Get user_name from users table.""" + user = db.query(User).filter(User.id == user_id).first() + return user.user_name if user else None + + @staticmethod + def _decrypt_optional_value(value: Any) -> Any: + """Decrypt an optional encrypted value. Returns original if decryption fails.""" + if not value: + return value + try: + return decrypt_api_key(value) + except Exception: + return value diff --git a/knowledge_runtime/knowledge_runtime/services/index_executor.py b/knowledge_runtime/knowledge_runtime/services/index_executor.py index 974db6e96..882f67a89 100644 --- a/knowledge_runtime/knowledge_runtime/services/index_executor.py +++ b/knowledge_runtime/knowledge_runtime/services/index_executor.py @@ -9,11 +9,14 @@ import logging from typing import Any +from sqlalchemy.orm import Session + from knowledge_engine.embedding.factory import ( create_embedding_model_from_runtime_config, ) from knowledge_engine.services.document_service import DocumentService from knowledge_engine.storage.factory import create_storage_backend_from_runtime_config +from knowledge_runtime.services.config_resolver import ConfigResolver from knowledge_runtime.services.content_fetcher import ContentFetcher from shared.models import RemoteIndexRequest @@ -24,19 +27,22 @@ class IndexExecutor: """Executes document indexing operations. This executor: - 1. Fetches binary content from the ContentRef - 2. Creates storage backend and embedding model from runtime configs - 3. Indexes the document using DocumentService + 1. Resolves configs from the database using ConfigResolver + 2. Fetches binary content from the ContentRef + 3. Creates storage backend and embedding model from resolved configs + 4. Indexes the document using DocumentService """ - def __init__(self) -> None: + def __init__(self, db: Session) -> None: + self._db = db + self._config_resolver = ConfigResolver() self._content_fetcher = ContentFetcher() async def execute(self, request: RemoteIndexRequest) -> dict[str, Any]: """Execute the indexing operation. Args: - request: The index request containing content reference and configs. + request: The index request (reference mode - configs resolved from DB). Returns: Indexing result with chunk_count, doc_ref, etc. @@ -45,6 +51,14 @@ async def execute(self, request: RemoteIndexRequest) -> dict[str, Any]: ValueError: If required configuration is missing. ContentFetchError: If content fetching fails. """ + # Resolve configs from database + config = self._config_resolver.resolve_index_config( + db=self._db, + knowledge_base_id=request.knowledge_base_id, + user_id=request.user_id, + document_id=request.document_id, + ) + # Fetch content from the content reference binary_data, source_file, file_extension = await self._content_fetcher.fetch( request.content_ref @@ -56,14 +70,12 @@ async def execute(self, request: RemoteIndexRequest) -> dict[str, Any]: if request.file_extension: file_extension = request.file_extension - # Create storage backend from retriever config + # Create storage backend and embedding model from resolved configs storage_backend = create_storage_backend_from_runtime_config( - request.retriever_config + config.retriever_config ) - - # Create embedding model from config embed_model = create_embedding_model_from_runtime_config( - request.embedding_model_config + config.embedding_model_config ) # Create document service @@ -73,8 +85,10 @@ async def execute(self, request: RemoteIndexRequest) -> dict[str, Any]: knowledge_id = str(request.knowledge_base_id) logger.info( - f"Indexing document for knowledge_base_id={request.knowledge_base_id}, " - f"source_file={source_file}, user_id={request.index_owner_user_id}" + "Indexing document for knowledge_base_id=%d, source_file=%s, user_id=%d", + request.knowledge_base_id, + source_file, + config.index_owner_user_id, ) # Index the document @@ -84,16 +98,15 @@ async def execute(self, request: RemoteIndexRequest) -> dict[str, Any]: source_file=source_file, file_extension=file_extension, embed_model=embed_model, - user_id=request.index_owner_user_id, - splitter_config=request.splitter_config.model_dump( - mode="json", exclude_none=True - ), + user_id=config.index_owner_user_id, + splitter_config=config.splitter_config, document_id=request.document_id, ) logger.info( - f"Indexing complete: chunk_count={result.get('chunk_count')}, " - f"doc_ref={result.get('doc_ref')}" + "Indexing complete: chunk_count=%s, doc_ref=%s", + result.get("chunk_count"), + result.get("doc_ref"), ) return result diff --git a/knowledge_runtime/knowledge_runtime/services/query_executor.py b/knowledge_runtime/knowledge_runtime/services/query_executor.py index a8c0c3323..df994483c 100644 --- a/knowledge_runtime/knowledge_runtime/services/query_executor.py +++ b/knowledge_runtime/knowledge_runtime/services/query_executor.py @@ -9,13 +9,15 @@ import logging from typing import Any +from sqlalchemy.orm import Session + from knowledge_engine.embedding.factory import ( create_embedding_model_from_runtime_config, ) from knowledge_engine.query.executor import QueryExecutor as KnowledgeQueryExecutor from knowledge_engine.storage.factory import create_storage_backend_from_runtime_config +from knowledge_runtime.services.config_resolver import ConfigResolver from shared.models import ( - RemoteKnowledgeBaseQueryConfig, RemoteQueryRecord, RemoteQueryRequest, RemoteQueryResponse, @@ -28,27 +30,32 @@ class QueryExecutor: """Executes RAG query operations. This executor: - 1. Creates storage backends and embedding models for each knowledge base - 2. Executes queries against each KB - 3. Aggregates and sorts results by score + 1. Resolves configs for each knowledge base from the database + 2. Creates storage backends and embedding models for each KB + 3. Executes queries against each KB + 4. Aggregates and sorts results by score """ + def __init__(self, db: Session) -> None: + self._db = db + self._config_resolver = ConfigResolver() + async def execute(self, request: RemoteQueryRequest) -> RemoteQueryResponse: """Execute the query operation. Args: - request: The query request containing query text and KB configs. + request: The query request (reference mode - configs resolved from DB). Returns: Query response with ranked records. """ all_records: list[RemoteQueryRecord] = [] - # Query each knowledge base - for kb_config in request.knowledge_base_configs: + # Resolve configs for each knowledge base + for knowledge_base_id in request.knowledge_base_ids: records = await self._query_knowledge_base( request=request, - kb_config=kb_config, + knowledge_base_id=knowledge_base_id, ) all_records.extend(records) @@ -62,8 +69,10 @@ async def execute(self, request: RemoteQueryRequest) -> RemoteQueryResponse: ) logger.info( - f"Query complete: query='{request.query[:50]}...', " - f"total_results={len(all_records)}, returned={len(limited_records)}" + "Query complete: query='%s...', total_results=%d, returned=%d", + request.query[:50], + len(all_records), + len(limited_records), ) return RemoteQueryResponse( @@ -75,25 +84,30 @@ async def execute(self, request: RemoteQueryRequest) -> RemoteQueryResponse: async def _query_knowledge_base( self, request: RemoteQueryRequest, - kb_config: RemoteKnowledgeBaseQueryConfig, + knowledge_base_id: int, ) -> list[RemoteQueryRecord]: """Query a single knowledge base. Args: request: The original query request. - kb_config: Configuration for this specific knowledge base. + knowledge_base_id: ID of the knowledge base to query. Returns: List of records from this knowledge base. """ - # Create storage backend - storage_backend = create_storage_backend_from_runtime_config( - kb_config.retriever_config + # Resolve config from database + config = self._config_resolver.resolve_query_config( + db=self._db, + knowledge_base_id=knowledge_base_id, + user_id=request.user_id, ) - # Create embedding model + # Create storage backend and embedding model + storage_backend = create_storage_backend_from_runtime_config( + config.retriever_config + ) embed_model = create_embedding_model_from_runtime_config( - kb_config.embedding_model_config + config.embedding_model_config ) # Create query executor @@ -102,16 +116,14 @@ async def _query_knowledge_base( embed_model=embed_model, ) - # Build knowledge_id - knowledge_id = str(kb_config.knowledge_base_id) - # Execute query + knowledge_id = str(knowledge_base_id) result = await executor.execute( knowledge_id=knowledge_id, query=request.query, - retrieval_config=kb_config.retrieval_config, + retrieval_config=config.retrieval_config, metadata_condition=request.metadata_condition, - user_id=kb_config.index_owner_user_id, + user_id=config.index_owner_user_id, ) # Convert to RemoteQueryRecord format @@ -123,32 +135,25 @@ async def _query_knowledge_base( title=record.get("title", ""), score=record.get("score"), metadata=record.get("metadata"), - knowledge_base_id=kb_config.knowledge_base_id, + knowledge_base_id=knowledge_base_id, document_id=self._extract_document_id(record), ) ) logger.info( - f"Queried KB: knowledge_base_id={kb_config.knowledge_base_id}, " - f"records={len(records)}" + "Queried KB: knowledge_base_id=%d, records=%d", + knowledge_base_id, + len(records), ) return records def _extract_document_id(self, record: dict[str, Any]) -> int | None: - """Extract document ID from record metadata. - - Args: - record: Query result record. - - Returns: - Document ID if found, None otherwise. - """ + """Extract document ID from record metadata.""" metadata = record.get("metadata") or {} doc_ref = metadata.get("doc_ref") if doc_ref and isinstance(doc_ref, str): try: - # doc_ref format is typically "doc_xxx" or numeric string if doc_ref.startswith("doc_"): return int(doc_ref[4:]) return int(doc_ref) @@ -157,14 +162,5 @@ def _extract_document_id(self, record: dict[str, Any]) -> int | None: return None def _estimate_tokens(self, text: str) -> int: - """Estimate token count for text. - - Uses a simple heuristic: ~4 characters per token. - - Args: - text: Text to estimate tokens for. - - Returns: - Estimated token count. - """ + """Estimate token count (~4 characters per token).""" return len(text) // 4 diff --git a/knowledge_runtime/pyproject.toml b/knowledge_runtime/pyproject.toml index 0a05d0e40..6de0148a5 100644 --- a/knowledge_runtime/pyproject.toml +++ b/knowledge_runtime/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "httpx[brotli,socks]>=0.26.0", "python-dotenv>=1.0.0", "tenacity>=8.2.3", + "sqlalchemy>=2.0.0", "wegent-knowledge-engine", "wegent-shared", ] diff --git a/knowledge_runtime/tests/conftest.py b/knowledge_runtime/tests/conftest.py index 38b650b17..aaf2b60d8 100644 --- a/knowledge_runtime/tests/conftest.py +++ b/knowledge_runtime/tests/conftest.py @@ -4,14 +4,20 @@ """Pytest configuration for knowledge_runtime tests.""" +from unittest.mock import MagicMock + import pytest +from knowledge_runtime.services.config_resolver import ConfigResolver + +# --------------------------------------------------------------------------- +# Fixtures for admin/other tests +# --------------------------------------------------------------------------- + @pytest.fixture def mock_storage_backend(): """Create a mock storage backend for testing.""" - from unittest.mock import MagicMock - backend = MagicMock() backend.test_connection.return_value = True backend.retrieve.return_value = {"records": []} @@ -25,7 +31,127 @@ def mock_storage_backend(): @pytest.fixture def mock_embed_model(): """Create a mock embedding model for testing.""" - from unittest.mock import MagicMock + model = MagicMock() + return model + + +# --------------------------------------------------------------------------- +# Fixtures for ConfigResolver tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def resolver() -> ConfigResolver: + """Create a ConfigResolver instance.""" + return ConfigResolver() + + +@pytest.fixture +def mock_db() -> MagicMock: + """Create a mock database session.""" + return MagicMock() + + +# --------------------------------------------------------------------------- +# Factory helpers for ConfigResolver tests +# --------------------------------------------------------------------------- + +_SENTINEL = object() + + +def _make_kb_kind( + knowledge_base_id: int = 1, + user_id: int = 42, + retrieval_config: dict | None = None, +) -> MagicMock: + """Create a mock KnowledgeBase Kind record.""" + if retrieval_config is None: + retrieval_config = { + "retriever_name": "test-retriever", + "retriever_namespace": "default", + "embedding_config": { + "model_name": "text-embedding-3-small", + "model_namespace": "default", + }, + "top_k": 10, + "score_threshold": 0.8, + "retrieval_mode": "vector", + } + kb = MagicMock() + kb.id = knowledge_base_id + kb.user_id = user_id + kb.kind = "KnowledgeBase" + kb.is_active = True + kb.json = {"spec": {"retrievalConfig": retrieval_config}} + return kb + +def _make_retriever_kind( + name: str = "test-retriever", + namespace: str = "default", + storage_config: dict | None = None, +) -> MagicMock: + """Create a mock Retriever Kind record.""" + if storage_config is None: + storage_config = { + "type": "qdrant", + "url": "http://localhost:6333", + "username": "admin", + "password": "encrypted_password", + "apiKey": "encrypted_api_key", + "indexStrategy": {"mode": "per_dataset"}, + "ext": {}, + } + retriever = MagicMock() + retriever.name = name + retriever.namespace = namespace + retriever.kind = "Retriever" + retriever.json = {"spec": {"storageConfig": storage_config}} + return retriever + + +def _make_model_kind( + model_name: str = "text-embedding-3-small", + model_namespace: str = "default", + spec: dict | None = None, +) -> MagicMock: + """Create a mock Model Kind record.""" + if spec is None: + spec = { + "protocol": "openai", + "modelConfig": { + "env": { + "api_key": "sk-encrypted-key", + "base_url": "https://api.openai.com/v1", + "model_id": "text-embedding-3-small", + "custom_headers": {}, + }, + }, + "embeddingConfig": {"dimensions": 1536}, + } model = MagicMock() + model.name = model_name + model.namespace = model_namespace + model.kind = "Model" + model.json = {"spec": spec} return model + + +def _make_document( + document_id: int = 100, splitter_config: dict | None = _SENTINEL +) -> MagicMock: + """Create a mock KnowledgeDocument record.""" + doc = MagicMock() + doc.id = document_id + if splitter_config is _SENTINEL: + splitter_config = {"chunk_size": 512} + doc.splitter_config = splitter_config + return doc + + +def _make_user(user_id: int = 42, user_name: str = "testuser") -> MagicMock: + """Create a mock User record.""" + user = MagicMock() + user.id = user_id + user.user_name = user_name + return user diff --git a/knowledge_runtime/tests/test_admin_executor.py b/knowledge_runtime/tests/test_admin_executor.py index 3ad9a53a7..386ba2367 100644 --- a/knowledge_runtime/tests/test_admin_executor.py +++ b/knowledge_runtime/tests/test_admin_executor.py @@ -9,19 +9,19 @@ import pytest from knowledge_runtime.services.admin_executor import AdminExecutor +from knowledge_runtime.services.config_resolver import AdminResolvedConfig from shared.models import ( RemoteDeleteDocumentIndexRequest, RemoteDropKnowledgeIndexRequest, RemoteListChunksRequest, RemotePurgeKnowledgeIndexRequest, - RemoteTestConnectionRequest, RuntimeRetrieverConfig, ) @pytest.fixture -def retriever_config(): - """Create a sample retriever config.""" +def mock_retriever_config(): + """Create a sample resolved retriever config.""" return RuntimeRetrieverConfig( name="test-retriever", namespace="default", @@ -34,8 +34,9 @@ def retriever_config(): @pytest.fixture def admin_executor(): - """Create an AdminExecutor instance.""" - return AdminExecutor() + """Create an AdminExecutor instance with a mock db session.""" + mock_db = MagicMock() + return AdminExecutor(db=mock_db) class TestAdminExecutor: @@ -43,14 +44,13 @@ class TestAdminExecutor: @pytest.mark.asyncio async def test_delete_document_index_success( - self, admin_executor, retriever_config + self, admin_executor, mock_retriever_config ) -> None: """Test successful document index deletion.""" request = RemoteDeleteDocumentIndexRequest( knowledge_base_id=1, + user_id=42, document_ref="doc_123", - index_owner_user_id=7, - retriever_config=retriever_config, ) mock_storage_backend = MagicMock() @@ -63,6 +63,13 @@ async def test_delete_document_index_success( "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + result = await admin_executor.delete_document_index(request) assert result["status"] == "success" @@ -75,13 +82,12 @@ async def test_delete_document_index_success( @pytest.mark.asyncio async def test_purge_knowledge_index_success( - self, admin_executor, retriever_config + self, admin_executor, mock_retriever_config ) -> None: """Test successful knowledge base purge.""" request = RemotePurgeKnowledgeIndexRequest( knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=retriever_config, + user_id=42, ) mock_storage_backend = MagicMock() @@ -94,6 +100,13 @@ async def test_purge_knowledge_index_success( "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + result = await admin_executor.purge_knowledge_index(request) assert result["status"] == "success" @@ -105,13 +118,12 @@ async def test_purge_knowledge_index_success( @pytest.mark.asyncio async def test_drop_knowledge_index_success( - self, admin_executor, retriever_config + self, admin_executor, mock_retriever_config ) -> None: """Test successful knowledge base index drop.""" request = RemoteDropKnowledgeIndexRequest( knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=retriever_config, + user_id=42, ) mock_storage_backend = MagicMock() @@ -121,6 +133,13 @@ async def test_drop_knowledge_index_success( "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + result = await admin_executor.drop_knowledge_index(request) assert result["status"] == "success" @@ -130,12 +149,13 @@ async def test_drop_knowledge_index_success( ) @pytest.mark.asyncio - async def test_list_chunks_success(self, admin_executor, retriever_config) -> None: + async def test_list_chunks_success( + self, admin_executor, mock_retriever_config + ) -> None: """Test successful chunk listing.""" request = RemoteListChunksRequest( knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=retriever_config, + user_id=42, max_chunks=100, ) @@ -162,6 +182,13 @@ async def test_list_chunks_success(self, admin_executor, retriever_config) -> No "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + result = await admin_executor.list_chunks(request) assert result.total == 2 @@ -171,12 +198,13 @@ async def test_list_chunks_success(self, admin_executor, retriever_config) -> No assert result.chunks[1].content == "Chunk 2 content" @pytest.mark.asyncio - async def test_list_chunks_empty(self, admin_executor, retriever_config) -> None: + async def test_list_chunks_empty( + self, admin_executor, mock_retriever_config + ) -> None: """Test empty chunk listing.""" request = RemoteListChunksRequest( knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=retriever_config, + user_id=42, ) mock_storage_backend = MagicMock() @@ -187,60 +215,26 @@ async def test_list_chunks_empty(self, admin_executor, retriever_config) -> None "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + result = await admin_executor.list_chunks(request) assert result.total == 0 assert len(result.chunks) == 0 - @pytest.mark.asyncio - async def test_test_connection_success( - self, admin_executor, retriever_config - ) -> None: - """Test successful connection test.""" - request = RemoteTestConnectionRequest(retriever_config=retriever_config) - - mock_storage_backend = MagicMock() - mock_storage_backend.test_connection.return_value = True - - with patch( - "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", - return_value=mock_storage_backend, - ): - result = await admin_executor.test_connection(request) - - assert result["success"] is True - assert result["message"] == "Connection successful" - - @pytest.mark.asyncio - async def test_test_connection_failure( - self, admin_executor, retriever_config - ) -> None: - """Test failed connection test.""" - request = RemoteTestConnectionRequest(retriever_config=retriever_config) - - mock_storage_backend = MagicMock() - mock_storage_backend.test_connection.side_effect = Exception( - "Connection refused" - ) - - with patch( - "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", - return_value=mock_storage_backend, - ): - result = await admin_executor.test_connection(request) - - assert result["success"] is False - assert "Connection refused" in result["message"] - @pytest.mark.asyncio async def test_list_chunks_with_metadata_condition( - self, admin_executor, retriever_config + self, admin_executor, mock_retriever_config ) -> None: """Test chunk listing with metadata condition filter.""" request = RemoteListChunksRequest( knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=retriever_config, + user_id=42, metadata_condition={"doc_ref": "doc_123"}, ) @@ -252,6 +246,13 @@ async def test_list_chunks_with_metadata_condition( "knowledge_runtime.services.admin_executor.create_storage_backend_from_runtime_config", return_value=mock_storage_backend, ): + admin_executor._config_resolver.resolve_admin_config = MagicMock( + return_value=AdminResolvedConfig( + index_owner_user_id=7, + retriever_config=mock_retriever_config, + ) + ) + await admin_executor.list_chunks(request) mock_storage_backend.get_all_chunks.assert_called_once() diff --git a/knowledge_runtime/tests/test_config_resolver_builders.py b/knowledge_runtime/tests/test_config_resolver_builders.py new file mode 100644 index 000000000..fff45c035 --- /dev/null +++ b/knowledge_runtime/tests/test_config_resolver_builders.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ConfigResolver builder methods.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from knowledge_runtime.services.config_resolver import ( + ConfigResolutionError, + ConfigResolver, +) +from shared.models import ( + RuntimeEmbeddingModelConfig, + RuntimeRetrieverConfig, +) + +from .conftest import _make_model_kind, _make_retriever_kind + + +class TestBuildResolvedRetrieverConfig: + """Tests for ConfigResolver._build_resolved_retriever_config.""" + + def test_success(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test building resolved retriever config with decrypted credentials.""" + storage_config = { + "type": "qdrant", + "url": "http://localhost:6333", + "username": "admin", + "password": "enc_password", + "apiKey": "enc_api_key", + "indexStrategy": {"mode": "per_dataset"}, + "ext": {"timeout": 30}, + } + retriever = _make_retriever_kind(storage_config=storage_config) + + with ( + patch.object(resolver, "_get_retriever_kind", return_value=retriever), + patch.object( + resolver, + "_decrypt_optional_value", + side_effect=lambda v: ( + f"decrypted_{v}" if v and v.startswith("enc_") else v + ), + ), + ): + result = resolver._build_resolved_retriever_config( + db=mock_db, + user_id=42, + name="test-retriever", + namespace="default", + ) + + assert isinstance(result, RuntimeRetrieverConfig) + assert result.name == "test-retriever" + assert result.namespace == "default" + assert result.storage_config["type"] == "qdrant" + assert result.storage_config["url"] == "http://localhost:6333" + assert result.storage_config["password"] == "decrypted_enc_password" + assert result.storage_config["apiKey"] == "decrypted_enc_api_key" + assert result.storage_config["indexStrategy"] == {"mode": "per_dataset"} + assert result.storage_config["ext"] == {"timeout": 30} + + def test_default_index_strategy( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test that missing indexStrategy defaults to per_dataset.""" + storage_config = { + "type": "qdrant", + "url": "http://localhost:6333", + } + retriever = _make_retriever_kind(storage_config=storage_config) + + with ( + patch.object(resolver, "_get_retriever_kind", return_value=retriever), + patch.object(resolver, "_decrypt_optional_value", side_effect=lambda v: v), + ): + result = resolver._build_resolved_retriever_config( + db=mock_db, + user_id=42, + name="test-retriever", + namespace="default", + ) + + assert result.storage_config["indexStrategy"] == {"mode": "per_dataset"} + assert result.storage_config["ext"] == {} + + def test_retriever_not_found( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test that missing retriever raises ConfigResolutionError.""" + with patch.object(resolver, "_get_retriever_kind", return_value=None): + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._build_resolved_retriever_config( + db=mock_db, + user_id=42, + name="missing-retriever", + namespace="default", + ) + + assert exc_info.value.code == "config_not_found" + assert "missing-retriever" in str(exc_info.value) + + +class TestBuildResolvedEmbeddingModelConfig: + """Tests for ConfigResolver._build_resolved_embedding_model_config.""" + + def test_success(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test building resolved embedding model config.""" + spec = { + "protocol": "openai", + "modelConfig": { + "env": { + "api_key": "sk-enc-key", + "base_url": "https://api.openai.com/v1", + "model_id": "text-embedding-3-small", + "custom_headers": {"X-Custom": "value"}, + }, + }, + "embeddingConfig": {"dimensions": 1536}, + } + model = _make_model_kind(spec=spec) + + with ( + patch.object(resolver, "_get_model_kind", return_value=model), + patch.object( + resolver, + "_decrypt_optional_value", + side_effect=lambda v: f"dec_{v}" if v and v.startswith("sk-enc") else v, + ), + patch( + "knowledge_runtime.services.config_resolver.process_custom_headers_placeholders", + side_effect=lambda h, u: h, + ), + ): + result = resolver._build_resolved_embedding_model_config( + db=mock_db, + user_id=42, + model_name="text-embedding-3-small", + model_namespace="default", + user_name="testuser", + ) + + assert isinstance(result, RuntimeEmbeddingModelConfig) + assert result.model_name == "text-embedding-3-small" + assert result.model_namespace == "default" + assert result.resolved_config["protocol"] == "openai" + assert result.resolved_config["api_key"] == "dec_sk-enc-key" + assert result.resolved_config["base_url"] == "https://api.openai.com/v1" + assert result.resolved_config["model_id"] == "text-embedding-3-small" + assert result.resolved_config["dimensions"] == 1536 + + def test_protocol_fallback_to_env_model( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test that protocol falls back to env.model when spec.protocol is absent.""" + spec = { + "modelConfig": { + "env": { + "model": "anthropic", + "api_key": "sk-test", + "base_url": "https://api.anthropic.com", + "model_id": "claude-embedding", + "custom_headers": {}, + }, + }, + "embeddingConfig": {}, + } + model = _make_model_kind(spec=spec) + + with ( + patch.object(resolver, "_get_model_kind", return_value=model), + patch.object(resolver, "_decrypt_optional_value", side_effect=lambda v: v), + patch( + "knowledge_runtime.services.config_resolver.process_custom_headers_placeholders", + side_effect=lambda h, u: h, + ), + ): + result = resolver._build_resolved_embedding_model_config( + db=mock_db, + user_id=42, + model_name="claude-embedding", + model_namespace="default", + user_name="testuser", + ) + + assert result.resolved_config["protocol"] == "anthropic" + + def test_no_embedding_config_section( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test missing embeddingConfig section sets dimensions to None.""" + spec = { + "protocol": "openai", + "modelConfig": { + "env": { + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + "model_id": "text-embedding-3-small", + "custom_headers": {}, + }, + }, + } + model = _make_model_kind(spec=spec) + + with ( + patch.object(resolver, "_get_model_kind", return_value=model), + patch.object(resolver, "_decrypt_optional_value", side_effect=lambda v: v), + patch( + "knowledge_runtime.services.config_resolver.process_custom_headers_placeholders", + side_effect=lambda h, u: h, + ), + ): + result = resolver._build_resolved_embedding_model_config( + db=mock_db, + user_id=42, + model_name="text-embedding-3-small", + model_namespace="default", + user_name="testuser", + ) + + assert result.resolved_config["dimensions"] is None + + def test_model_not_found( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test missing model raises ConfigResolutionError.""" + with patch.object(resolver, "_get_model_kind", return_value=None): + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._build_resolved_embedding_model_config( + db=mock_db, + user_id=42, + model_name="missing-model", + model_namespace="default", + user_name="testuser", + ) + + assert exc_info.value.code == "config_not_found" + assert "missing-model" in str(exc_info.value) + + def test_custom_headers_placeholders_processed( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test that custom_headers with placeholders are processed.""" + spec = { + "protocol": "openai", + "modelConfig": { + "env": { + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + "model_id": "text-embedding-3-small", + "custom_headers": {"X-User": "${user.name}"}, + }, + }, + "embeddingConfig": {}, + } + model = _make_model_kind(spec=spec) + + with ( + patch.object(resolver, "_get_model_kind", return_value=model), + patch.object(resolver, "_decrypt_optional_value", side_effect=lambda v: v), + ): + result = resolver._build_resolved_embedding_model_config( + db=mock_db, + user_id=42, + model_name="text-embedding-3-small", + model_namespace="default", + user_name="alice", + ) + + assert result.resolved_config["custom_headers"]["X-User"] == "alice" diff --git a/knowledge_runtime/tests/test_config_resolver_errors.py b/knowledge_runtime/tests/test_config_resolver_errors.py new file mode 100644 index 000000000..84d6a3497 --- /dev/null +++ b/knowledge_runtime/tests/test_config_resolver_errors.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ConfigResolutionError and process_custom_headers_placeholders.""" + +import pytest + +from knowledge_runtime.services.config_resolver import ConfigResolutionError +from shared.utils.placeholder import process_custom_headers_placeholders + + +class TestConfigResolutionError: + """Tests for ConfigResolutionError.""" + + def test_stores_error_code(self) -> None: + """Test that error code is stored correctly.""" + error = ConfigResolutionError("config_not_found", "Something was not found") + assert error.code == "config_not_found" + assert str(error) == "Something was not found" + + def test_is_value_error(self) -> None: + """Test that ConfigResolutionError is a ValueError.""" + error = ConfigResolutionError("config_incomplete", "Incomplete config") + assert isinstance(error, ValueError) + + def test_raises_and_catches(self) -> None: + """Test that ConfigResolutionError can be raised and caught.""" + with pytest.raises(ConfigResolutionError) as exc_info: + raise ConfigResolutionError("config_not_found", "KB not found") + + assert exc_info.value.code == "config_not_found" + assert "KB not found" in str(exc_info.value) + + +class TestProcessCustomHeadersPlaceholders: + """Tests for process_custom_headers_placeholders helper.""" + + def test_replaces_user_name_placeholder(self) -> None: + """Test ${user.name} placeholder is replaced.""" + headers = {"X-User": "${user.name}"} + result = process_custom_headers_placeholders(headers, user_name="alice") + assert result["X-User"] == "alice" + + def test_replaces_in_mixed_string(self) -> None: + """Test placeholder replacement within a longer string.""" + headers = {"Authorization": "Bearer ${user.name}-token"} + result = process_custom_headers_placeholders(headers, user_name="bob") + assert result["Authorization"] == "Bearer bob-token" + + def test_no_placeholders(self) -> None: + """Test headers without placeholders pass through unchanged.""" + headers = {"X-Custom": "static-value"} + result = process_custom_headers_placeholders(headers, user_name="alice") + assert result["X-Custom"] == "static-value" + + def test_none_user_name(self) -> None: + """Test placeholder with None user_name uses empty string.""" + headers = {"X-User": "${user.name}"} + result = process_custom_headers_placeholders(headers, user_name=None) + assert result["X-User"] == "" + + def test_empty_headers(self) -> None: + """Test empty headers dict returns empty dict.""" + result = process_custom_headers_placeholders({}, user_name="alice") + assert result == {} + + def test_none_headers(self) -> None: + """Test None headers returns None.""" + result = process_custom_headers_placeholders(None, user_name="alice") + assert result is None + + def test_non_string_values_preserved(self) -> None: + """Test non-string header values are preserved as-is.""" + headers = {"X-Count": 42, "X-Flag": True} + result = process_custom_headers_placeholders(headers, user_name="alice") + assert result["X-Count"] == 42 + assert result["X-Flag"] is True + + def test_multiple_placeholders(self) -> None: + """Test multiple placeholders in the same header value.""" + headers = {"X-Auth": "user=${user.name}&type=bearer"} + result = process_custom_headers_placeholders(headers, user_name="charlie") + assert result["X-Auth"] == "user=charlie&type=bearer" diff --git a/knowledge_runtime/tests/test_config_resolver_helpers.py b/knowledge_runtime/tests/test_config_resolver_helpers.py new file mode 100644 index 000000000..7a96bec33 --- /dev/null +++ b/knowledge_runtime/tests/test_config_resolver_helpers.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ConfigResolver helper methods.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from knowledge_runtime.services.config_resolver import ( + ConfigResolutionError, + ConfigResolver, +) + +from .conftest import ( + _make_document, + _make_kb_kind, + _make_model_kind, + _make_retriever_kind, + _make_user, +) + + +class TestGetKnowledgeBase: + """Tests for ConfigResolver._get_knowledge_base.""" + + def test_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test KB found returns the record.""" + kb = _make_kb_kind(knowledge_base_id=1, user_id=42) + mock_db.query.return_value.filter.return_value.first.return_value = kb + + result = resolver._get_knowledge_base(mock_db, 1) + + assert result.id == 1 + assert result.user_id == 42 + + def test_not_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test KB not found raises ConfigResolutionError.""" + mock_db.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._get_knowledge_base(mock_db, 999) + + assert exc_info.value.code == "config_not_found" + assert "999 not found" in str(exc_info.value) + + +class TestParseKbRetrievalConfig: + """Tests for ConfigResolver._parse_kb_retrieval_config.""" + + def test_full_config(self, resolver: ConfigResolver) -> None: + """Test parsing a fully populated retrieval config.""" + retrieval_config = { + "retriever_name": "my-retriever", + "retriever_namespace": "production", + "embedding_config": { + "model_name": "text-embedding-3-large", + "model_namespace": "custom", + }, + "top_k": 15, + "score_threshold": 0.6, + "retrieval_mode": "hybrid", + "hybrid_weights": {"vector_weight": 0.8, "keyword_weight": 0.2}, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + result = resolver._parse_kb_retrieval_config(kb) + + assert result["retriever_name"] == "my-retriever" + assert result["retriever_namespace"] == "production" + assert result["embedding_model_name"] == "text-embedding-3-large" + assert result["embedding_model_namespace"] == "custom" + assert result["top_k"] == 15 + assert result["score_threshold"] == 0.6 + assert result["retrieval_mode"] == "hybrid" + assert result["hybrid_weights"] == {"vector_weight": 0.8, "keyword_weight": 0.2} + + def test_defaults(self, resolver: ConfigResolver) -> None: + """Test parsing with minimal config uses defaults.""" + retrieval_config = { + "retriever_name": "my-retriever", + "embedding_config": { + "model_name": "text-embedding-3-small", + }, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + result = resolver._parse_kb_retrieval_config(kb) + + assert result["retriever_namespace"] == "default" + assert result["embedding_model_namespace"] == "default" + assert result["top_k"] == 20 + assert result["score_threshold"] == 0.7 + assert result["retrieval_mode"] == "vector" + assert result["hybrid_weights"] is None + + def test_missing_retriever_name(self, resolver: ConfigResolver) -> None: + """Test missing retriever_name raises ConfigResolutionError.""" + retrieval_config = { + "embedding_config": {"model_name": "text-embedding-3-small"}, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._parse_kb_retrieval_config(kb) + + assert exc_info.value.code == "config_incomplete" + assert "missing retriever_name" in str(exc_info.value) + + def test_missing_embedding_model_name(self, resolver: ConfigResolver) -> None: + """Test missing embedding model_name raises ConfigResolutionError.""" + retrieval_config = { + "retriever_name": "my-retriever", + "embedding_config": {}, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._parse_kb_retrieval_config(kb) + + assert exc_info.value.code == "config_incomplete" + assert "incomplete embedding config" in str(exc_info.value) + + def test_empty_retrieval_config(self, resolver: ConfigResolver) -> None: + """Test empty retrieval config raises ConfigResolutionError.""" + kb = _make_kb_kind(knowledge_base_id=1, user_id=42, retrieval_config={}) + + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._parse_kb_retrieval_config(kb) + + assert exc_info.value.code == "config_incomplete" + + +class TestGetSplitterConfig: + """Tests for ConfigResolver._get_splitter_config.""" + + def test_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test document found returns its splitter_config.""" + doc = _make_document(document_id=100, splitter_config={"chunk_size": 512}) + mock_db.query.return_value.filter.return_value.first.return_value = doc + + result = resolver._get_splitter_config(mock_db, 100) + + assert result == {"chunk_size": 512} + + def test_not_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test document not found raises ConfigResolutionError.""" + mock_db.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(ConfigResolutionError) as exc_info: + resolver._get_splitter_config(mock_db, 999) + + assert exc_info.value.code == "config_not_found" + assert "999 not found" in str(exc_info.value) + + def test_null_splitter_config( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test document with null splitter_config returns empty dict.""" + doc = _make_document(document_id=100, splitter_config=None) + mock_db.query.return_value.filter.return_value.first.return_value = doc + + result = resolver._get_splitter_config(mock_db, 100) + + assert result == {} + + +class TestGetUserName: + """Tests for ConfigResolver._get_user_name.""" + + def test_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test user found returns user_name.""" + user = _make_user(user_id=42, user_name="testuser") + mock_db.query.return_value.filter.return_value.first.return_value = user + + result = resolver._get_user_name(mock_db, 42) + + assert result == "testuser" + + def test_not_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test user not found returns None.""" + mock_db.query.return_value.filter.return_value.first.return_value = None + + result = resolver._get_user_name(mock_db, 999) + + assert result is None + + +class TestDecryptOptionalValue: + """Tests for ConfigResolver._decrypt_optional_value.""" + + def test_none_value(self) -> None: + """Test None value returns None.""" + result = ConfigResolver._decrypt_optional_value(None) + assert result is None + + def test_empty_string(self) -> None: + """Test empty string returns empty string.""" + result = ConfigResolver._decrypt_optional_value("") + assert result == "" + + def test_successful_decrypt(self) -> None: + """Test successful decryption returns decrypted value.""" + with patch( + "knowledge_runtime.services.config_resolver.decrypt_api_key", + return_value="decrypted_value", + ): + result = ConfigResolver._decrypt_optional_value("encrypted_value") + assert result == "decrypted_value" + + def test_failed_decrypt_returns_original(self) -> None: + """Test failed decryption returns the original value.""" + with patch( + "knowledge_runtime.services.config_resolver.decrypt_api_key", + side_effect=Exception("Decryption failed"), + ): + result = ConfigResolver._decrypt_optional_value("encrypted_value") + assert result == "encrypted_value" + + def test_plain_text_api_key(self) -> None: + """Test plain text API key (sk- prefix) is returned as-is by decrypt_api_key.""" + with patch( + "knowledge_runtime.services.config_resolver.decrypt_api_key", + return_value="sk-plain-key", + ): + result = ConfigResolver._decrypt_optional_value("sk-plain-key") + assert result == "sk-plain-key" + + +class TestGetRetrieverKind: + """Tests for ConfigResolver._get_retriever_kind.""" + + def test_default_namespace_queries_with_user_priority( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test default namespace queries with user_id priority.""" + retriever = _make_retriever_kind() + mock_db.query.return_value.filter.return_value.filter.return_value.order_by.return_value.first.return_value = ( + retriever + ) + + result = resolver._get_retriever_kind( + mock_db, user_id=42, name="test-retriever", namespace="default" + ) + + assert result is not None + mock_db.query.assert_called() + + def test_group_namespace_queries_without_user_id( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test group namespace queries without user_id filter.""" + retriever = _make_retriever_kind(namespace="team-ns") + mock_db.query.return_value.filter.return_value.first.return_value = retriever + + result = resolver._get_retriever_kind( + mock_db, user_id=42, name="test-retriever", namespace="team-ns" + ) + + assert result is not None + + def test_group_namespace_fallback_to_public( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test group namespace falls back to public retriever (user_id=0).""" + public_retriever = _make_retriever_kind() + + # First query (group namespace) returns None, second query (public fallback) returns result + mock_db.query.return_value.filter.return_value.first.side_effect = [ + None, + public_retriever, + ] + + result = resolver._get_retriever_kind( + mock_db, user_id=42, name="test-retriever", namespace="team-ns" + ) + + assert result is not None + + +class TestGetModelKind: + """Tests for ConfigResolver._get_model_kind.""" + + def test_default_namespace_queries_with_user_priority( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test default namespace queries with user_id priority.""" + model = _make_model_kind() + mock_db.query.return_value.filter.return_value.filter.return_value.order_by.return_value.first.return_value = ( + model + ) + + result = resolver._get_model_kind( + db=mock_db, + user_id=42, + model_name="text-embedding-3-small", + model_namespace="default", + ) + + assert result is not None + + def test_group_namespace_queries_without_user_id( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test group namespace queries without user_id filter.""" + model = _make_model_kind(model_namespace="team-ns") + mock_db.query.return_value.filter.return_value.first.return_value = model + + result = resolver._get_model_kind( + db=mock_db, + user_id=42, + model_name="text-embedding-3-small", + model_namespace="team-ns", + ) + + assert result is not None diff --git a/knowledge_runtime/tests/test_config_resolver_resolve.py b/knowledge_runtime/tests/test_config_resolver_resolve.py new file mode 100644 index 000000000..3d8167bcb --- /dev/null +++ b/knowledge_runtime/tests/test_config_resolver_resolve.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ConfigResolver resolve_index_config and resolve_query_config.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from knowledge_runtime.services.config_resolver import ( + ConfigResolutionError, + ConfigResolver, + IndexConfig, + QueryConfig, +) +from shared.models import ( + RuntimeEmbeddingModelConfig, + RuntimeRetrievalConfig, + RuntimeRetrieverConfig, +) + +from .conftest import ( + _make_kb_kind, + _make_model_kind, + _make_retriever_kind, +) + + +class TestResolveIndexConfig: + """Tests for ConfigResolver.resolve_index_config.""" + + def test_success_with_document_id( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test successful index config resolution with document_id.""" + kb = _make_kb_kind(knowledge_base_id=1, user_id=42) + retriever = _make_retriever_kind() + model = _make_model_kind() + user = MagicMock() + user.id = 42 + user.user_name = "testuser" + + mock_db.query.return_value.filter.return_value.filter.return_value.order_by.return_value.first.return_value = ( + retriever + ) + mock_db.query.return_value.filter.return_value.first.side_effect = [ + kb, + user, + ] + + with ( + patch.object(resolver, "_get_knowledge_base", return_value=kb), + patch.object(resolver, "_get_user_name", return_value="testuser"), + patch.object( + resolver, + "_build_resolved_retriever_config", + return_value=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={"type": "qdrant", "url": "http://localhost:6333"}, + ), + ), + patch.object( + resolver, + "_build_resolved_embedding_model_config", + return_value=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={"protocol": "openai"}, + ), + ), + patch.object( + resolver, + "_get_splitter_config", + return_value={"chunk_size": 1024}, + ), + ): + result = resolver.resolve_index_config( + mock_db, + knowledge_base_id=1, + user_id=42, + document_id=100, + ) + + assert isinstance(result, IndexConfig) + assert result.index_owner_user_id == 42 + assert result.user_name == "testuser" + assert result.splitter_config == {"chunk_size": 1024} + assert result.retriever_config.name == "test-retriever" + assert result.embedding_model_config.model_name == "text-embedding-3-small" + + def test_success_without_document_id( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test index config resolution without document_id yields empty splitter_config.""" + kb = _make_kb_kind(knowledge_base_id=1, user_id=42) + + with ( + patch.object(resolver, "_get_knowledge_base", return_value=kb), + patch.object(resolver, "_get_user_name", return_value="testuser"), + patch.object( + resolver, + "_build_resolved_retriever_config", + return_value=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={}, + ), + ), + patch.object( + resolver, + "_build_resolved_embedding_model_config", + return_value=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={}, + ), + ), + ): + result = resolver.resolve_index_config( + mock_db, + knowledge_base_id=1, + user_id=42, + document_id=None, + ) + + assert result.splitter_config == {} + + def test_kb_not_found(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test that ConfigResolutionError is raised when KB is not found.""" + with patch.object( + resolver, + "_get_knowledge_base", + side_effect=ConfigResolutionError( + "config_not_found", "Knowledge base 999 not found" + ), + ): + with pytest.raises(ConfigResolutionError) as exc_info: + resolver.resolve_index_config( + mock_db, + knowledge_base_id=999, + user_id=42, + ) + assert exc_info.value.code == "config_not_found" + + +class TestResolveQueryConfig: + """Tests for ConfigResolver.resolve_query_config.""" + + def test_success(self, resolver: ConfigResolver, mock_db: MagicMock) -> None: + """Test successful query config resolution.""" + kb = _make_kb_kind(knowledge_base_id=1, user_id=42) + + with ( + patch.object(resolver, "_get_knowledge_base", return_value=kb), + patch.object(resolver, "_get_user_name", return_value="testuser"), + patch.object( + resolver, + "_build_resolved_retriever_config", + return_value=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={"type": "qdrant"}, + ), + ), + patch.object( + resolver, + "_build_resolved_embedding_model_config", + return_value=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={"protocol": "openai"}, + ), + ), + ): + result = resolver.resolve_query_config( + mock_db, + knowledge_base_id=1, + user_id=42, + ) + + assert isinstance(result, QueryConfig) + assert result.knowledge_base_id == 1 + assert result.index_owner_user_id == 42 + assert result.user_name == "testuser" + assert result.retriever_config.name == "test-retriever" + assert result.embedding_model_config.model_name == "text-embedding-3-small" + assert isinstance(result.retrieval_config, RuntimeRetrievalConfig) + assert result.retrieval_config.top_k == 10 + assert result.retrieval_config.score_threshold == 0.8 + assert result.retrieval_config.retrieval_mode == "vector" + + def test_hybrid_retrieval_mode( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test query config with hybrid retrieval mode includes weights.""" + retrieval_config = { + "retriever_name": "test-retriever", + "retriever_namespace": "default", + "embedding_config": { + "model_name": "text-embedding-3-small", + "model_namespace": "default", + }, + "top_k": 5, + "score_threshold": 0.5, + "retrieval_mode": "hybrid", + "hybrid_weights": { + "vector_weight": 0.7, + "keyword_weight": 0.3, + }, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + with ( + patch.object(resolver, "_get_knowledge_base", return_value=kb), + patch.object(resolver, "_get_user_name", return_value="testuser"), + patch.object( + resolver, + "_build_resolved_retriever_config", + return_value=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={}, + ), + ), + patch.object( + resolver, + "_build_resolved_embedding_model_config", + return_value=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={}, + ), + ), + ): + result = resolver.resolve_query_config( + mock_db, + knowledge_base_id=1, + user_id=42, + ) + + assert result.retrieval_config.retrieval_mode == "hybrid" + assert result.retrieval_config.vector_weight == 0.7 + assert result.retrieval_config.keyword_weight == 0.3 + + def test_default_retrieval_values( + self, resolver: ConfigResolver, mock_db: MagicMock + ) -> None: + """Test query config with minimal retrieval config uses defaults.""" + retrieval_config = { + "retriever_name": "test-retriever", + "retriever_namespace": "default", + "embedding_config": { + "model_name": "text-embedding-3-small", + "model_namespace": "default", + }, + } + kb = _make_kb_kind( + knowledge_base_id=1, user_id=42, retrieval_config=retrieval_config + ) + + with ( + patch.object(resolver, "_get_knowledge_base", return_value=kb), + patch.object(resolver, "_get_user_name", return_value="testuser"), + patch.object( + resolver, + "_build_resolved_retriever_config", + return_value=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={}, + ), + ), + patch.object( + resolver, + "_build_resolved_embedding_model_config", + return_value=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={}, + ), + ), + ): + result = resolver.resolve_query_config( + mock_db, + knowledge_base_id=1, + user_id=42, + ) + + assert result.retrieval_config.top_k == 20 + assert result.retrieval_config.score_threshold == 0.7 + assert result.retrieval_config.retrieval_mode == "vector" + assert result.retrieval_config.vector_weight is None + assert result.retrieval_config.keyword_weight is None diff --git a/knowledge_runtime/tests/test_index_executor.py b/knowledge_runtime/tests/test_index_executor.py index e8bea087d..a8b5efb8c 100644 --- a/knowledge_runtime/tests/test_index_executor.py +++ b/knowledge_runtime/tests/test_index_executor.py @@ -8,6 +8,7 @@ import pytest +from knowledge_runtime.services.config_resolver import IndexConfig from knowledge_runtime.services.index_executor import IndexExecutor from shared.models import ( PresignedUrlContentRef, @@ -18,11 +19,9 @@ @pytest.fixture -def index_request(): - """Create a sample index request.""" - return RemoteIndexRequest( - knowledge_base_id=1, - document_id=100, +def mock_index_config(): + """Create a sample resolved IndexConfig.""" + return IndexConfig( index_owner_user_id=7, retriever_config=RuntimeRetrieverConfig( name="test-retriever", @@ -40,6 +39,17 @@ def index_request(): "api_key": "test-key", }, ), + splitter_config={"chunk_size": 500, "chunk_overlap": 50}, + ) + + +@pytest.fixture +def index_request(): + """Create a sample index request (reference mode).""" + return RemoteIndexRequest( + knowledge_base_id=1, + user_id=42, + document_id=100, content_ref=PresignedUrlContentRef( kind="presigned_url", url="https://storage.example.com/bucket/test.pdf", @@ -53,11 +63,12 @@ class TestIndexExecutor: """Tests for IndexExecutor.""" @pytest.mark.asyncio - async def test_execute_success(self, index_request) -> None: + async def test_execute_success(self, index_request, mock_index_config) -> None: """Test successful index execution.""" mock_storage_backend = MagicMock() mock_embed_model = MagicMock() mock_document_service = MagicMock() + mock_db = MagicMock() mock_document_service.index_document_from_binary = AsyncMock( return_value={ @@ -84,8 +95,13 @@ async def test_execute_success(self, index_request) -> None: return_value=mock_document_service, ), patch("httpx.AsyncClient") as mock_client, - patch("knowledge_runtime.config._settings", None), # Reset settings cache + patch("knowledge_runtime.config._settings", None), ): + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + return_value=mock_index_config + ) + mock_response = MagicMock() mock_response.content = b"test content" mock_response.raise_for_status = MagicMock() @@ -93,21 +109,29 @@ async def test_execute_success(self, index_request) -> None: return_value=mock_response ) - executor = IndexExecutor() result = await executor.execute(index_request) assert result["chunk_count"] == 5 assert result["doc_ref"] == "100" mock_document_service.index_document_from_binary.assert_called_once() + # Verify ConfigResolver was called with correct args + executor._config_resolver.resolve_index_config.assert_called_once_with( + db=mock_db, + knowledge_base_id=1, + user_id=42, + document_id=100, + ) + @pytest.mark.asyncio async def test_execute_uses_request_metadata_over_fetched( - self, index_request + self, index_request, mock_index_config ) -> None: """Test that request metadata overrides fetched content metadata.""" mock_storage_backend = MagicMock() mock_embed_model = MagicMock() mock_document_service = MagicMock() + mock_db = MagicMock() mock_document_service.index_document_from_binary = AsyncMock( return_value={ @@ -130,8 +154,13 @@ async def test_execute_uses_request_metadata_over_fetched( return_value=mock_document_service, ), patch("httpx.AsyncClient") as mock_client, - patch("knowledge_runtime.config._settings", None), # Reset settings cache + patch("knowledge_runtime.config._settings", None), ): + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + return_value=mock_index_config + ) + mock_response = MagicMock() # Content from fetch has different filename mock_response.content = b"content" @@ -140,7 +169,6 @@ async def test_execute_uses_request_metadata_over_fetched( return_value=mock_response ) - executor = IndexExecutor() await executor.execute(index_request) call_kwargs = mock_document_service.index_document_from_binary.call_args.kwargs @@ -149,31 +177,99 @@ async def test_execute_uses_request_metadata_over_fetched( assert call_kwargs["file_extension"] == ".pdf" @pytest.mark.asyncio - async def test_execute_content_fetch_error_propagates(self, index_request) -> None: + async def test_execute_uses_resolved_configs( + self, index_request, mock_index_config + ) -> None: + """Test that resolved configs are passed to storage/embedding factories.""" + mock_storage_backend = MagicMock() + mock_embed_model = MagicMock() + mock_document_service = MagicMock() + mock_db = MagicMock() + + mock_document_service.index_document_from_binary = AsyncMock( + return_value={"chunk_count": 1, "doc_ref": "100"} + ) + + with ( + patch( + "knowledge_runtime.services.index_executor.create_storage_backend_from_runtime_config", + return_value=mock_storage_backend, + ) as mock_create_storage, + patch( + "knowledge_runtime.services.index_executor.create_embedding_model_from_runtime_config", + return_value=mock_embed_model, + ) as mock_create_embed, + patch( + "knowledge_runtime.services.index_executor.DocumentService", + return_value=mock_document_service, + ), + patch("httpx.AsyncClient") as mock_client, + patch("knowledge_runtime.config._settings", None), + ): + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + return_value=mock_index_config + ) + + mock_response = MagicMock() + mock_response.content = b"content" + mock_response.raise_for_status = MagicMock() + mock_client.return_value.__aenter__.return_value.get = AsyncMock( + return_value=mock_response + ) + + await executor.execute(index_request) + + # Verify factories were called with resolved configs + mock_create_storage.assert_called_once_with(mock_index_config.retriever_config) + mock_create_embed.assert_called_once_with( + mock_index_config.embedding_model_config + ) + + # Verify document service was called with resolved user_id and splitter_config + call_kwargs = mock_document_service.index_document_from_binary.call_args.kwargs + assert call_kwargs["user_id"] == 7 + assert call_kwargs["splitter_config"] == { + "chunk_size": 500, + "chunk_overlap": 50, + } + + @pytest.mark.asyncio + async def test_execute_content_fetch_error_propagates( + self, index_request, mock_index_config + ) -> None: """Test that content fetch errors propagate correctly.""" from knowledge_runtime.services.content_fetcher import ContentFetchError + mock_db = MagicMock() + with ( patch("httpx.AsyncClient") as mock_client, - patch("knowledge_runtime.config._settings", None), # Reset settings cache + patch("knowledge_runtime.config._settings", None), ): + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + return_value=mock_index_config + ) + mock_client.return_value.__aenter__.return_value.get = AsyncMock( side_effect=ContentFetchError("Fetch failed", retryable=True) ) - executor = IndexExecutor() - with pytest.raises(ContentFetchError) as exc_info: await executor.execute(index_request) assert exc_info.value.retryable @pytest.mark.asyncio - async def test_execute_storage_error_propagates(self, index_request) -> None: + async def test_execute_storage_error_propagates( + self, index_request, mock_index_config + ) -> None: """Test that storage backend errors propagate correctly.""" mock_storage_backend = MagicMock() mock_embed_model = MagicMock() mock_document_service = MagicMock() + mock_db = MagicMock() mock_document_service.index_document_from_binary = AsyncMock( side_effect=ValueError("Storage connection failed") @@ -193,8 +289,13 @@ async def test_execute_storage_error_propagates(self, index_request) -> None: return_value=mock_document_service, ), patch("httpx.AsyncClient") as mock_client, - patch("knowledge_runtime.config._settings", None), # Reset settings cache + patch("knowledge_runtime.config._settings", None), ): + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + return_value=mock_index_config + ) + mock_response = MagicMock() mock_response.content = b"content" mock_response.raise_for_status = MagicMock() @@ -202,7 +303,26 @@ async def test_execute_storage_error_propagates(self, index_request) -> None: return_value=mock_response ) - executor = IndexExecutor() - with pytest.raises(ValueError, match="Storage connection failed"): await executor.execute(index_request) + + @pytest.mark.asyncio + async def test_execute_config_resolution_error_propagates( + self, index_request + ) -> None: + """Test that config resolution errors propagate correctly.""" + from knowledge_runtime.services.config_resolver import ConfigResolutionError + + mock_db = MagicMock() + + executor = IndexExecutor(db=mock_db) + executor._config_resolver.resolve_index_config = MagicMock( + side_effect=ConfigResolutionError( + "config_not_found", "Knowledge base 1 not found" + ) + ) + + with pytest.raises(ConfigResolutionError) as exc_info: + await executor.execute(index_request) + + assert exc_info.value.code == "config_not_found" diff --git a/knowledge_runtime/tests/test_query_executor.py b/knowledge_runtime/tests/test_query_executor.py index 420edf3e1..aba03909e 100644 --- a/knowledge_runtime/tests/test_query_executor.py +++ b/knowledge_runtime/tests/test_query_executor.py @@ -8,9 +8,9 @@ import pytest +from knowledge_runtime.services.config_resolver import QueryConfig from knowledge_runtime.services.query_executor import QueryExecutor from shared.models import ( - RemoteKnowledgeBaseQueryConfig, RemoteQueryRequest, RemoteQueryResponse, RuntimeEmbeddingModelConfig, @@ -19,63 +19,42 @@ ) +def _make_query_config(knowledge_base_id: int = 1) -> QueryConfig: + """Create a sample resolved QueryConfig.""" + return QueryConfig( + knowledge_base_id=knowledge_base_id, + index_owner_user_id=7, + retriever_config=RuntimeRetrieverConfig( + name="test-retriever", + namespace="default", + storage_config={ + "type": "qdrant", + "url": "http://localhost:6333", + }, + ), + embedding_model_config=RuntimeEmbeddingModelConfig( + model_name="text-embedding-3-small", + model_namespace="default", + resolved_config={ + "protocol": "openai", + "api_key": "test-key", + }, + ), + retrieval_config=RuntimeRetrievalConfig( + top_k=5, + score_threshold=0.7, + ), + ) + + @pytest.fixture def query_request(): - """Create a sample query request.""" + """Create a sample query request (reference mode).""" return RemoteQueryRequest( knowledge_base_ids=[1, 2], + user_id=42, query="test query", max_results=10, - knowledge_base_configs=[ - RemoteKnowledgeBaseQueryConfig( - knowledge_base_id=1, - index_owner_user_id=7, - retriever_config=RuntimeRetrieverConfig( - name="retriever-1", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://localhost:6333", - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="text-embedding-3-small", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "api_key": "test-key", - }, - ), - retrieval_config=RuntimeRetrievalConfig( - top_k=5, - score_threshold=0.7, - ), - ), - RemoteKnowledgeBaseQueryConfig( - knowledge_base_id=2, - index_owner_user_id=7, - retriever_config=RuntimeRetrieverConfig( - name="retriever-2", - namespace="default", - storage_config={ - "type": "qdrant", - "url": "http://localhost:6333", - }, - ), - embedding_model_config=RuntimeEmbeddingModelConfig( - model_name="text-embedding-3-small", - model_namespace="default", - resolved_config={ - "protocol": "openai", - "api_key": "test-key", - }, - ), - retrieval_config=RuntimeRetrievalConfig( - top_k=5, - score_threshold=0.7, - ), - ), - ], ) @@ -94,7 +73,7 @@ async def test_execute_returns_aggregated_results(self, query_request) -> None: return_value={ "records": [ { - "content": "Result from KB1", + "content": "Result from KB", "title": "Doc1", "score": 0.95, "metadata": {"doc_ref": "doc_100"}, @@ -103,6 +82,10 @@ async def test_execute_returns_aggregated_results(self, query_request) -> None: } ) + mock_db = MagicMock() + config_1 = _make_query_config(1) + config_2 = _make_query_config(2) + with ( patch( "knowledge_runtime.services.query_executor.create_storage_backend_from_runtime_config", @@ -117,7 +100,11 @@ async def test_execute_returns_aggregated_results(self, query_request) -> None: return_value=mock_kb_executor, ), ): - executor = QueryExecutor() + executor = QueryExecutor(db=mock_db) + executor._config_resolver.resolve_query_config = MagicMock( + side_effect=[config_1, config_2] + ) + result = await executor.execute(query_request) assert isinstance(result, RemoteQueryResponse) @@ -149,6 +136,10 @@ async def mock_execute(**kwargs): mock_kb_executor = MagicMock() mock_kb_executor.execute = mock_execute + mock_db = MagicMock() + config_1 = _make_query_config(1) + config_2 = _make_query_config(2) + with ( patch( "knowledge_runtime.services.query_executor.create_storage_backend_from_runtime_config", @@ -163,7 +154,11 @@ async def mock_execute(**kwargs): return_value=mock_kb_executor, ), ): - executor = QueryExecutor() + executor = QueryExecutor(db=mock_db) + executor._config_resolver.resolve_query_config = MagicMock( + side_effect=[config_1, config_2] + ) + result = await executor.execute(query_request) # Should be sorted by score descending @@ -188,6 +183,10 @@ async def test_execute_respects_max_results(self, query_request) -> None: } ) + mock_db = MagicMock() + config_1 = _make_query_config(1) + config_2 = _make_query_config(2) + with ( patch( "knowledge_runtime.services.query_executor.create_storage_backend_from_runtime_config", @@ -202,7 +201,11 @@ async def test_execute_respects_max_results(self, query_request) -> None: return_value=mock_kb_executor, ), ): - executor = QueryExecutor() + executor = QueryExecutor(db=mock_db) + executor._config_resolver.resolve_query_config = MagicMock( + side_effect=[config_1, config_2] + ) + result = await executor.execute(query_request) assert len(result.records) == 1 @@ -217,6 +220,10 @@ async def test_execute_empty_results(self, query_request) -> None: mock_kb_executor = MagicMock() mock_kb_executor.execute = AsyncMock(return_value={"records": []}) + mock_db = MagicMock() + config_1 = _make_query_config(1) + config_2 = _make_query_config(2) + with ( patch( "knowledge_runtime.services.query_executor.create_storage_backend_from_runtime_config", @@ -231,17 +238,68 @@ async def test_execute_empty_results(self, query_request) -> None: return_value=mock_kb_executor, ), ): - executor = QueryExecutor() + executor = QueryExecutor(db=mock_db) + executor._config_resolver.resolve_query_config = MagicMock( + side_effect=[config_1, config_2] + ) + result = await executor.execute(query_request) assert result.total == 0 assert len(result.records) == 0 assert result.total_estimated_tokens == 0 + @pytest.mark.asyncio + async def test_execute_resolves_config_for_each_kb(self, query_request) -> None: + """Test that ConfigResolver is called for each knowledge_base_id.""" + mock_storage_backend = MagicMock() + mock_embed_model = MagicMock() + + mock_kb_executor = MagicMock() + mock_kb_executor.execute = AsyncMock(return_value={"records": []}) + + mock_db = MagicMock() + config_1 = _make_query_config(1) + config_2 = _make_query_config(2) + + with ( + patch( + "knowledge_runtime.services.query_executor.create_storage_backend_from_runtime_config", + return_value=mock_storage_backend, + ), + patch( + "knowledge_runtime.services.query_executor.create_embedding_model_from_runtime_config", + return_value=mock_embed_model, + ), + patch( + "knowledge_runtime.services.query_executor.KnowledgeQueryExecutor", + return_value=mock_kb_executor, + ), + ): + executor = QueryExecutor(db=mock_db) + mock_resolve = MagicMock(side_effect=[config_1, config_2]) + executor._config_resolver.resolve_query_config = mock_resolve + + await executor.execute(query_request) + + # ConfigResolver should be called twice, once for each KB + assert mock_resolve.call_count == 2 + mock_resolve.assert_any_call( + db=mock_db, + knowledge_base_id=1, + user_id=42, + ) + mock_resolve.assert_any_call( + db=mock_db, + knowledge_base_id=2, + user_id=42, + ) + @pytest.mark.asyncio async def test_extract_document_id_from_doc_ref(self) -> None: """Test document ID extraction from various doc_ref formats.""" - executor = QueryExecutor() + mock_db = MagicMock() + executor = QueryExecutor(db=mock_db) # Test "doc_XXX" format assert ( @@ -260,7 +318,8 @@ async def test_extract_document_id_from_doc_ref(self) -> None: def test_estimate_tokens(self) -> None: """Test token estimation heuristic.""" - executor = QueryExecutor() + mock_db = MagicMock() + executor = QueryExecutor(db=mock_db) # ~4 characters per token assert executor._estimate_tokens("test") == 1 # 4 chars diff --git a/knowledge_runtime/uv.lock b/knowledge_runtime/uv.lock index 4a2bf3115..b5c4d4431 100644 --- a/knowledge_runtime/uv.lock +++ b/knowledge_runtime/uv.lock @@ -823,7 +823,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -3623,6 +3623,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "python-dotenv" }, + { name = "sqlalchemy" }, { name = "tenacity" }, { name = "uvicorn" }, { name = "wegent-knowledge-engine" }, @@ -3647,6 +3648,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.5.3" }, { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "sqlalchemy", specifier = ">=2.0.0" }, { name = "tenacity", specifier = ">=8.2.3" }, { name = "uvicorn", specifier = ">=0.30.0" }, { name = "wegent-knowledge-engine", editable = "../knowledge_engine" }, diff --git a/shared/models/__init__.py b/shared/models/__init__.py index 4ad0292b3..3dbe4b90b 100644 --- a/shared/models/__init__.py +++ b/shared/models/__init__.py @@ -63,10 +63,6 @@ RemoteQueryRequest, RemoteQueryResponse, RemoteRagError, - RemoteTestConnectionRequest, - RuntimeEmbeddingModelConfig, - RuntimeRetrievalConfig, - RuntimeRetrieverConfig, ) # OpenAI Request Converter @@ -101,6 +97,11 @@ TransportFactory, TransportType, ) +from .runtime_config import ( + RuntimeEmbeddingModelConfig, + RuntimeRetrievalConfig, + RuntimeRetrieverConfig, +) from .splitter_config import ( FlatChunkConfig, HierarchicalChunkConfig, @@ -140,7 +141,6 @@ "RemoteListChunksRequest", "RemoteListChunkRecord", "RemoteListChunksResponse", - "RemoteTestConnectionRequest", "RemoteQueryRequest", "RemoteQueryRecord", "RemoteQueryResponse", diff --git a/shared/models/knowledge_runtime_protocol.py b/shared/models/knowledge_runtime_protocol.py index a3be27ec2..054f53807 100644 --- a/shared/models/knowledge_runtime_protocol.py +++ b/shared/models/knowledge_runtime_protocol.py @@ -6,15 +6,15 @@ from __future__ import annotations -from collections import Counter from datetime import datetime from typing import Annotated, Any, Literal -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field -from .splitter_config import ( - NormalizedSplitterConfig, - build_runtime_default_splitter_config, +from shared.models.runtime_config import ( + RuntimeEmbeddingModelConfig, + RuntimeRetrievalConfig, + RuntimeRetrieverConfig, ) @@ -46,6 +46,7 @@ class PresignedUrlContentRef(KnowledgeRuntimeProtocolModel): Field(discriminator="kind"), ] + RetrievalPolicy = Literal[ "chunk_only", "summary_first", @@ -53,8 +54,6 @@ class PresignedUrlContentRef(KnowledgeRuntimeProtocolModel): "hybrid", ] -RetrievalMode = Literal["vector", "keyword", "hybrid"] - class KnowledgeRuntimeAuth(KnowledgeRuntimeProtocolModel): """Simple internal auth carrier for the runtime service.""" @@ -72,32 +71,6 @@ class RemoteRagError(KnowledgeRuntimeProtocolModel): details: dict[str, Any] | None = None -class RuntimeRetrieverConfig(KnowledgeRuntimeProtocolModel): - """Resolved retriever identity and storage configuration.""" - - name: str - namespace: str = "default" - storage_config: dict[str, Any] = Field(default_factory=dict) - - -class RuntimeEmbeddingModelConfig(KnowledgeRuntimeProtocolModel): - """Resolved embedding model configuration.""" - - model_name: str - model_namespace: str = "default" - resolved_config: dict[str, Any] = Field(default_factory=dict) - - -class RuntimeRetrievalConfig(KnowledgeRuntimeProtocolModel): - """Normalized retrieval config for a single knowledge base target.""" - - top_k: int = Field(default=20, gt=0) - score_threshold: float = Field(default=0.7, ge=0.0, le=1.0) - retrieval_mode: RetrievalMode = "vector" - vector_weight: float | None = Field(default=None, ge=0.0, le=1.0) - keyword_weight: float | None = Field(default=None, ge=0.0, le=1.0) - - class RemoteKnowledgeBaseQueryConfig(KnowledgeRuntimeProtocolModel): """Resolved execution config for one queryable knowledge base.""" @@ -109,102 +82,65 @@ class RemoteKnowledgeBaseQueryConfig(KnowledgeRuntimeProtocolModel): class RemoteIndexRequest(KnowledgeRuntimeProtocolModel): - """Index request sent from Backend to knowledge_runtime.""" + """Index request - reference mode. KR resolves configs from DB.""" knowledge_base_id: int + user_id: int document_id: int | None = None - index_owner_user_id: int - retriever_config: RuntimeRetrieverConfig - embedding_model_config: RuntimeEmbeddingModelConfig - splitter_config: NormalizedSplitterConfig = Field( - default_factory=build_runtime_default_splitter_config - ) source_file: str | None = None file_extension: str | None = None - index_families: list[str] = Field(default_factory=lambda: ["chunk_vector"]) content_ref: ContentRef trace_context: dict[str, Any] | None = None - user_name: str | None = None extensions: dict[str, Any] | None = None class RemoteDeleteDocumentIndexRequest(KnowledgeRuntimeProtocolModel): - """Delete-document-index request sent from Backend to knowledge_runtime.""" + """Delete-document-index request - reference mode.""" knowledge_base_id: int + user_id: int document_ref: str - index_owner_user_id: int | None = None - retriever_config: RuntimeRetrieverConfig - enabled_index_families: list[str] = Field(default_factory=lambda: ["chunk_vector"]) extensions: dict[str, Any] | None = None class RemotePurgeKnowledgeIndexRequest(KnowledgeRuntimeProtocolModel): - """Delete-all-chunks request sent from Backend to knowledge_runtime.""" + """Purge-knowledge-index request - reference mode.""" knowledge_base_id: int - index_owner_user_id: int - retriever_config: RuntimeRetrieverConfig + user_id: int extensions: dict[str, Any] | None = None class RemoteDropKnowledgeIndexRequest(KnowledgeRuntimeProtocolModel): - """Drop-physical-index request sent from Backend to knowledge_runtime.""" + """Drop-physical-index request - reference mode.""" knowledge_base_id: int - index_owner_user_id: int - retriever_config: RuntimeRetrieverConfig + user_id: int extensions: dict[str, Any] | None = None class RemoteListChunksRequest(KnowledgeRuntimeProtocolModel): - """List-chunks request sent from Backend to knowledge_runtime.""" + """List-chunks request - reference mode.""" knowledge_base_id: int - index_owner_user_id: int - retriever_config: RuntimeRetrieverConfig + user_id: int max_chunks: int = Field(default=10000, gt=0, le=10000) query: str | None = None metadata_condition: dict[str, Any] | None = None extensions: dict[str, Any] | None = None -class RemoteTestConnectionRequest(KnowledgeRuntimeProtocolModel): - """Test-connection request sent from Backend to knowledge_runtime.""" - - retriever_config: RuntimeRetrieverConfig - extensions: dict[str, Any] | None = None - - class RemoteQueryRequest(KnowledgeRuntimeProtocolModel): - """Query request sent from Backend to knowledge_runtime.""" + """Query request - reference mode. KR resolves configs from DB.""" knowledge_base_ids: list[int] + user_id: int query: str max_results: int = Field(default=5, gt=0) document_ids: list[int] | None = None metadata_condition: dict[str, Any] | None = None - user_name: str | None = None - knowledge_base_configs: list[RemoteKnowledgeBaseQueryConfig] - enabled_index_families: list[str] = Field(default_factory=lambda: ["chunk_vector"]) - retrieval_policy: RetrievalPolicy = "chunk_only" extensions: dict[str, Any] | None = None - @model_validator(mode="after") - def validate_knowledge_base_configs(self) -> "RemoteQueryRequest": - if not self.knowledge_base_configs: - raise ValueError("knowledge_base_configs must not be empty") - - requested_ids = list(self.knowledge_base_ids) - configured_ids = [ - config.knowledge_base_id for config in self.knowledge_base_configs - ] - if Counter(requested_ids) != Counter(configured_ids): - raise ValueError( - "knowledge_base_configs must align with knowledge_base_ids" - ) - return self - class RemoteQueryRecord(KnowledgeRuntimeProtocolModel): """Single retrieval record returned by knowledge_runtime.""" diff --git a/shared/models/runtime_config.py b/shared/models/runtime_config.py new file mode 100644 index 000000000..84da3d90a --- /dev/null +++ b/shared/models/runtime_config.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Runtime configuration models for retriever, embedding, and retrieval settings.""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class _RuntimeConfigModel(BaseModel): + """Base model for runtime configuration with strict field validation.""" + + model_config = ConfigDict(extra="forbid") + + +RetrievalMode = Literal["vector", "keyword", "hybrid"] + + +class RuntimeRetrieverConfig(_RuntimeConfigModel): + """Resolved retriever identity and storage configuration.""" + + name: str + namespace: str = "default" + storage_config: dict[str, Any] = Field(default_factory=dict) + + +class RuntimeEmbeddingModelConfig(_RuntimeConfigModel): + """Resolved embedding model configuration.""" + + model_name: str + model_namespace: str = "default" + resolved_config: dict[str, Any] = Field(default_factory=dict) + + +class RuntimeRetrievalConfig(_RuntimeConfigModel): + """Normalized retrieval config for a single knowledge base target.""" + + top_k: int = Field(default=20, gt=0) + score_threshold: float = Field(default=0.7, ge=0.0, le=1.0) + retrieval_mode: RetrievalMode = "vector" + vector_weight: float | None = Field(default=None, ge=0.0, le=1.0) + keyword_weight: float | None = Field(default=None, ge=0.0, le=1.0) diff --git a/shared/tests/test_knowledge_runtime_protocol.py b/shared/tests/test_knowledge_runtime_protocol.py index fd02da2bb..302b8e757 100644 --- a/shared/tests/test_knowledge_runtime_protocol.py +++ b/shared/tests/test_knowledge_runtime_protocol.py @@ -33,7 +33,6 @@ def test_shared_models_exports_knowledge_runtime_protocol_types() -> None: "RemoteQueryRequest", "RemoteQueryRecord", "RemoteQueryResponse", - "RemoteTestConnectionRequest", ] for name in exported_names: @@ -46,22 +45,8 @@ def test_remote_index_request_accepts_backend_attachment_stream_content_ref() -> request = remote_index_request.model_validate( { "knowledge_base_id": 11, + "user_id": 33, "document_id": 22, - "index_owner_user_id": 33, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "milvus", - "url": "http://milvus:19530", - }, - }, - "embedding_model_config": { - "model_name": "text-embedding-3-large", - "model_namespace": "default", - "resolved_config": {"protocol": "openai"}, - }, - "index_families": ["chunk_vector"], "content_ref": { "kind": "backend_attachment_stream", "url": "http://backend:8000/api/internal/rag/content/22", @@ -70,10 +55,29 @@ def test_remote_index_request_accepts_backend_attachment_stream_content_ref() -> } ) + assert request.knowledge_base_id == 11 + assert request.user_id == 33 assert request.content_ref.kind == "backend_attachment_stream" assert request.content_ref.auth_token == "test-token" +def test_remote_index_request_accepts_presigned_url_content_ref() -> None: + remote_index_request = _require_model("RemoteIndexRequest") + + request = remote_index_request.model_validate( + { + "knowledge_base_id": 11, + "user_id": 33, + "content_ref": { + "kind": "presigned_url", + "url": "https://storage.example.com/file.pdf", + }, + } + ) + + assert request.content_ref.kind == "presigned_url" + + def test_remote_index_request_rejects_unknown_content_ref_kind() -> None: remote_index_request = _require_model("RemoteIndexRequest") @@ -81,22 +85,7 @@ def test_remote_index_request_rejects_unknown_content_ref_kind() -> None: remote_index_request.model_validate( { "knowledge_base_id": 11, - "document_id": 22, - "index_owner_user_id": 33, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "milvus", - "url": "http://milvus:19530", - }, - }, - "embedding_model_config": { - "model_name": "text-embedding-3-large", - "model_namespace": "default", - "resolved_config": {"protocol": "openai"}, - }, - "index_families": ["chunk_vector"], + "user_id": 33, "content_ref": { "kind": "unsupported_kind", "url": "http://backend:8000/api/internal/rag/content/22", @@ -105,6 +94,22 @@ def test_remote_index_request_rejects_unknown_content_ref_kind() -> None: ) +def test_remote_index_request_rejects_missing_user_id() -> None: + remote_index_request = _require_model("RemoteIndexRequest") + + with pytest.raises(ValidationError): + remote_index_request.model_validate( + { + "knowledge_base_id": 11, + "content_ref": { + "kind": "backend_attachment_stream", + "url": "http://backend:8000/api/internal/rag/content/22", + "auth_token": "test-token", + }, + } + ) + + def test_remote_query_response_preserves_index_family_per_record() -> None: remote_query_response = _require_model("RemoteQueryResponse") @@ -136,21 +141,13 @@ def test_remote_query_response_preserves_index_family_per_record() -> None: ] -def test_remote_list_chunks_request_accepts_resolved_retriever_config() -> None: +def test_remote_list_chunks_request_accepts_reference_mode() -> None: remote_list_chunks_request = _require_model("RemoteListChunksRequest") request = remote_list_chunks_request.model_validate( { "knowledge_base_id": 1001, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, + "user_id": 42, "max_chunks": 1000, "query": "list_index_chunks", "metadata_condition": { @@ -163,7 +160,7 @@ def test_remote_list_chunks_request_accepts_resolved_retriever_config() -> None: ) assert request.knowledge_base_id == 1001 - assert request.retriever_config.storage_config["type"] == "qdrant" + assert request.user_id == 42 assert request.metadata_condition == { "operator": "and", "conditions": [ @@ -172,12 +169,13 @@ def test_remote_list_chunks_request_accepts_resolved_retriever_config() -> None: } -def test_remote_query_request_accepts_explicit_execution_configs() -> None: +def test_remote_query_request_accepts_reference_mode() -> None: remote_query_request = _require_model("RemoteQueryRequest") request = remote_query_request.model_validate( { "knowledge_base_ids": [1001], + "user_id": 42, "query": "release checklist", "max_results": 6, "metadata_condition": { @@ -187,46 +185,13 @@ def test_remote_query_request_accepts_explicit_execution_configs() -> None: {"key": "lang", "operator": "==", "value": "zh"}, ], }, - "knowledge_base_configs": [ - { - "knowledge_base_id": 1001, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - "indexStrategy": {"mode": "per_dataset"}, - }, - }, - "embedding_model_config": { - "model_name": "embed-a", - "model_namespace": "default", - "resolved_config": { - "protocol": "openai", - "model_id": "text-embedding-3-small", - "base_url": "https://api.openai.com/v1", - }, - }, - "retrieval_config": { - "top_k": 8, - "score_threshold": 0.55, - "retrieval_mode": "hybrid", - "vector_weight": 0.8, - "keyword_weight": 0.2, - }, - } - ], - "enabled_index_families": ["chunk_vector", "summary_vector_index"], - "retrieval_policy": "summary_then_chunk_expand", } ) - assert ( - request.knowledge_base_configs[0].retriever_config.storage_config["type"] - == "qdrant" - ) + assert request.knowledge_base_ids == [1001] + assert request.user_id == 42 + assert request.query == "release checklist" + assert request.max_results == 6 assert request.metadata_condition == { "operator": "or", "conditions": [ @@ -234,14 +199,9 @@ def test_remote_query_request_accepts_explicit_execution_configs() -> None: {"key": "lang", "operator": "==", "value": "zh"}, ], } - assert request.knowledge_base_configs[0].retrieval_config.top_k == 8 - assert request.enabled_index_families == [ - "chunk_vector", - "summary_vector_index", - ] -def test_remote_query_request_rejects_empty_knowledge_base_configs() -> None: +def test_remote_query_request_rejects_missing_user_id() -> None: remote_query_request = _require_model("RemoteQueryRequest") with pytest.raises(ValidationError): @@ -249,131 +209,67 @@ def test_remote_query_request_rejects_empty_knowledge_base_configs() -> None: { "knowledge_base_ids": [1001], "query": "release checklist", - "knowledge_base_configs": [], } ) -def test_remote_query_request_rejects_misaligned_knowledge_base_configs() -> None: +def test_remote_query_request_rejects_extra_fields() -> None: + """Reference mode rejects legacy value-mode fields.""" remote_query_request = _require_model("RemoteQueryRequest") with pytest.raises(ValidationError): remote_query_request.model_validate( { - "knowledge_base_ids": [1001, 1002], - "query": "release checklist", - "knowledge_base_configs": [ - { - "knowledge_base_id": 1001, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, - "embedding_model_config": { - "model_name": "embed-a", - "model_namespace": "default", - "resolved_config": { - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - }, - "retrieval_config": { - "top_k": 8, - "score_threshold": 0.55, - "retrieval_mode": "hybrid", - }, - } - ], - } - ) - - -def test_remote_query_request_rejects_duplicate_alignment_mismatch() -> None: - remote_query_request = _require_model("RemoteQueryRequest") - - with pytest.raises(ValidationError): - remote_query_request.model_validate( - { - "knowledge_base_ids": [1001, 1001], + "knowledge_base_ids": [1001], + "user_id": 42, "query": "release checklist", - "knowledge_base_configs": [ - { - "knowledge_base_id": 1001, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, - "embedding_model_config": { - "model_name": "embed-a", - "model_namespace": "default", - "resolved_config": { - "protocol": "openai", - "model_id": "text-embedding-3-small", - }, - }, - "retrieval_config": { - "top_k": 8, - "score_threshold": 0.55, - "retrieval_mode": "vector", - }, - } - ], + "knowledge_base_configs": [], } ) -def test_remote_delete_request_requires_resolved_retriever_config() -> None: +def test_remote_delete_request_accepts_reference_mode() -> None: remote_delete_request = _require_model("RemoteDeleteDocumentIndexRequest") request = remote_delete_request.model_validate( { "knowledge_base_id": 101, + "user_id": 303, "document_ref": "202", - "index_owner_user_id": 303, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "elasticsearch", - "url": "http://es:9200", - "indexStrategy": {"mode": "per_user", "prefix": "wegent"}, - }, - }, - "enabled_index_families": ["chunk_vector"], } ) - assert request.retriever_config.name == "retriever-a" - assert request.retriever_config.storage_config["type"] == "elasticsearch" + assert request.knowledge_base_id == 101 + assert request.user_id == 303 + assert request.document_ref == "202" -def test_remote_test_connection_request_requires_retriever_config() -> None: - remote_test_connection_request = _require_model("RemoteTestConnectionRequest") +def test_remote_purge_knowledge_index_request_accepts_reference_mode() -> None: + remote_purge_request = _require_model("RemotePurgeKnowledgeIndexRequest") - request = remote_test_connection_request.model_validate( + request = remote_purge_request.model_validate( { - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - } + "knowledge_base_id": 101, + "user_id": 42, } ) - assert request.retriever_config.storage_config["type"] == "qdrant" + assert request.knowledge_base_id == 101 + assert request.user_id == 42 + + +def test_remote_drop_knowledge_index_request_accepts_reference_mode() -> None: + remote_drop_request = _require_model("RemoteDropKnowledgeIndexRequest") + + request = remote_drop_request.model_validate( + { + "knowledge_base_id": 101, + "user_id": 42, + } + ) + + assert request.knowledge_base_id == 101 + assert request.user_id == 42 @pytest.mark.parametrize( @@ -399,44 +295,16 @@ def test_remote_test_connection_request_requires_retriever_config() -> None: "RemoteQueryRequest", { "knowledge_base_ids": [1], + "user_id": 42, "query": "release", "max_results": 0, - "knowledge_base_configs": [ - { - "knowledge_base_id": 1, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": { - "type": "qdrant", - "url": "http://qdrant:6333", - }, - }, - "embedding_model_config": { - "model_name": "embed-a", - "model_namespace": "default", - "resolved_config": {"protocol": "openai"}, - }, - "retrieval_config": { - "top_k": 8, - "score_threshold": 0.55, - "retrieval_mode": "vector", - }, - } - ], }, ), ( "RemoteListChunksRequest", { "knowledge_base_id": 1001, - "index_owner_user_id": 42, - "retriever_config": { - "name": "retriever-a", - "namespace": "default", - "storage_config": {"type": "qdrant", "url": "http://qdrant:6333"}, - }, + "user_id": 42, "max_chunks": 10001, }, ), diff --git a/shared/utils/placeholder.py b/shared/utils/placeholder.py new file mode 100644 index 000000000..3bd09f179 --- /dev/null +++ b/shared/utils/placeholder.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: 2026 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Placeholder replacement utilities for custom headers and configuration. + +Shared by Backend (embedding factory, model resolver) and Knowledge Runtime +(config resolver) for processing ${...} placeholder patterns in headers and +configuration dictionaries. +""" + +from __future__ import annotations + +import logging +import re +from typing import Any + +logger = logging.getLogger(__name__) + + +def resolve_value_from_source( + data_sources: dict[str, dict[str, Any]], source_spec: str +) -> str: + """Resolve value from specified data source using flexible notation. + + Supports dot-notation paths like "user.name" within a data source. + + Args: + data_sources: Dictionary containing all available data sources. + source_spec: Source specification in format "source_name.path" or just "path". + + Returns: + The resolved value as a string, or empty string if not found. + """ + try: + if "." in source_spec: + parts = source_spec.split(".", 1) + source_name = parts[0] + path = parts[1] + else: + source_name = "agent_config" + path = source_spec + + if source_name not in data_sources: + return "" + + data = data_sources[source_name] + keys = path.split(".") + current = data + + for key in keys: + if isinstance(current, dict) and key in current: + current = current[key] + elif ( + isinstance(current, list) and key.isdigit() and int(key) < len(current) + ): + current = current[int(key)] + else: + return "" + + return str(current) if current is not None else "" + except Exception: + return "" + + +def replace_placeholders_with_sources( + template: str, data_sources: dict[str, dict[str, Any]] +) -> str: + """Replace placeholders in template with values from multiple data sources. + + Args: + template: The template string with placeholders like ${source.path}. + data_sources: Dictionary containing all available data sources. + + Returns: + The template with placeholders replaced with actual values. + """ + pattern = r"\$\{([^}]+)\}" + + def replace_match(match: re.Match) -> str: + source_spec = match.group(1) + value = resolve_value_from_source(data_sources, source_spec) + return value + + return re.sub(pattern, replace_match, template) + + +def build_headers_with_placeholders( + headers: dict[str, Any], data_sources: dict[str, dict[str, Any]] +) -> dict[str, Any]: + """Build headers dict with placeholder replacement on string values. + + Args: + headers: Raw headers dictionary (may contain placeholders). + data_sources: Dictionary containing all available data sources. + + Returns: + Headers with placeholders replaced. + """ + result: dict[str, Any] = {} + try: + for k, v in headers.items(): + if isinstance(v, str): + result[k] = replace_placeholders_with_sources(v, data_sources) + else: + result[k] = v + except Exception as e: + logger.warning( + "Failed to build headers with placeholders; proceeding without. Error: %s", + e, + ) + return {} + return result + + +def process_custom_headers_placeholders( + custom_headers: dict[str, Any], + user_name: str | None = None, +) -> dict[str, Any]: + """Process placeholders in custom headers. + + Supports placeholder format: ${user.name} + + Args: + custom_headers: Custom headers dict (may contain placeholders). + user_name: User name for placeholder replacement. + + Returns: + Custom headers with placeholders replaced. + """ + if not custom_headers or not isinstance(custom_headers, dict): + return custom_headers + + data_sources: dict[str, dict[str, Any]] = { + "user": {"name": user_name or ""}, + } + + return build_headers_with_placeholders(custom_headers, data_sources)