Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 135 additions & 44 deletions backend/app/services/rag/remote_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Remote RAG Gateway for Knowledge Runtime.

This gateway sends requests to Knowledge Runtime using reference mode:
only passes references (user_id + kb_id/retriever_name), not full configurations.
The Knowledge Runtime resolves full configurations from the database.
"""

from __future__ import annotations

from typing import Any
Expand All @@ -11,8 +18,6 @@
from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.subtask_context import ContextType
from app.services.context import context_service
from app.services.rag.content_refs import build_content_ref_for_attachment
from app.services.rag.runtime_specs import (
ConnectionTestRuntimeSpec,
Expand All @@ -24,6 +29,7 @@
QueryRuntimeSpec,
)
from shared.models import (
KnowledgeBaseReference,
RemoteDeleteDocumentIndexRequest,
RemoteDropKnowledgeIndexRequest,
RemoteIndexRequest,
Expand All @@ -34,6 +40,7 @@
RemoteQueryResponse,
RemoteRagError,
RemoteTestConnectionRequest,
RetrieverReference,
)


Expand Down Expand Up @@ -137,37 +144,36 @@ async def index_document(
*,
db: Session | None = None,
) -> dict[str, Any]:
"""Index a document using reference mode.

Args:
spec: Index runtime spec with KB reference info.
db: Database session (required for content ref resolution).

Returns:
Indexing result.
"""
if db is None:
raise ValueError("db is required for RemoteRagGateway.index_document")
if spec.source.source_type != "attachment" or spec.source.attachment_id is None:
raise ValueError("RemoteRagGateway only supports attachment sources")

source_file, file_extension = _get_attachment_source_metadata(
db=db,
attachment_id=spec.source.attachment_id,
# Build KB reference - use index_owner_user_id for resolving config
kb_reference = KnowledgeBaseReference(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
)

payload = RemoteIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
document_id=spec.document_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config
or {
"name": spec.retriever_name,
"namespace": spec.retriever_namespace,
},
embedding_model_config=spec.embedding_model_config
or {
"model_name": spec.embedding_model_name,
"model_namespace": spec.embedding_model_namespace,
},
splitter_config=spec.splitter_config,
index_families=spec.index_families,
content_ref=build_content_ref_for_attachment(
db=db,
attachment_id=spec.source.attachment_id,
),
source_file=source_file,
file_extension=file_extension,
knowledge_base_reference=kb_reference,
splitter_config=spec.splitter_config,
index_families=spec.index_families,
user_name=spec.user_name,
)
return await self._post_model("/internal/rag/index", payload)
Expand All @@ -178,15 +184,36 @@ async def query(
*,
db: Session | None = None,
) -> dict[str, Any]:
"""Execute a RAG query using reference mode.

Args:
spec: Query runtime spec with KB references.
db: Database session (not used, kept for interface consistency).

Returns:
Query result with records.
"""
del db

# Build KB references from knowledge_base_configs
# Each KB config contains the index_owner_user_id needed for reference
kb_references = [
KnowledgeBaseReference(
knowledge_base_id=kb_config.knowledge_base_id,
user_id=kb_config.index_owner_user_id,
)
for kb_config in spec.knowledge_base_configs
]

payload = RemoteQueryRequest(
knowledge_base_ids=spec.knowledge_base_ids,
query=spec.query,
max_results=spec.max_results,
knowledge_base_references=kb_references,
user_id=spec.user_id or 0,
document_ids=spec.document_ids,
Comment on lines 208 to 214
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast instead of defaulting missing user_id to 0.

spec.user_id or 0 silently turns an omitted user context into the public/system user context, which can resolve the wrong references. Require user_id explicitly for remote reference-mode calls.

Suggested fix
+        if spec.user_id is None:
+            raise ValueError("user_id is required for remote query reference mode")
+
         payload = RemoteQueryRequest(
             knowledge_base_ids=spec.knowledge_base_ids,
             query=spec.query,
             max_results=spec.max_results,
             knowledge_base_references=kb_references,
-            user_id=spec.user_id or 0,
+            user_id=spec.user_id,
             document_ids=spec.document_ids,
             metadata_condition=spec.metadata_condition,
             user_name=spec.user_name,
             enabled_index_families=spec.enabled_index_families,
             retrieval_policy=spec.retrieval_policy,
@@
+        if spec.user_id is None:
+            raise ValueError("user_id is required for remote test_connection reference mode")
+
         # Build Retriever reference
         retriever_reference = RetrieverReference(
             name=spec.retriever_name,
             namespace=spec.retriever_namespace,
-            user_id=spec.user_id or 0,
+            user_id=spec.user_id,
         )

Also applies to: 367-375

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@backend/app/services/rag/remote_gateway.py` around lines 208 - 214, The
payload construction in RemoteQueryRequest currently defaults missing
spec.user_id to 0 (user_id=spec.user_id or 0), which can silently use a
system/public user; update the code in remote_gateway.py where
RemoteQueryRequest is built (and the analogous block around the second
occurrence) to require an explicit user_id: validate that spec.user_id is
present and raise an informative exception (or return an error) when it is
missing instead of using 0, and then pass spec.user_id directly into
RemoteQueryRequest.user_id; ensure the validation runs only for remote
reference-mode calls if that context is available.

metadata_condition=spec.metadata_condition,
user_name=spec.user_name,
knowledge_base_configs=spec.knowledge_base_configs,
enabled_index_families=spec.enabled_index_families,
retrieval_policy=spec.retrieval_policy,
)
Expand All @@ -203,12 +230,27 @@ async def delete_document_index(
*,
db: Session,
) -> dict[str, Any]:
"""Delete a document's index using reference mode.

Args:
spec: Delete runtime spec with KB reference.
db: Database session (not used, kept for interface consistency).

Returns:
Deletion result.
"""
del db

# Build KB reference
kb_reference = KnowledgeBaseReference(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
)

payload = RemoteDeleteDocumentIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
document_ref=spec.document_ref,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
knowledge_base_reference=kb_reference,
enabled_index_families=spec.enabled_index_families,
)
return await self._post_model("/internal/rag/delete-document-index", payload)
Expand All @@ -219,11 +261,26 @@ async def purge_knowledge_index(
*,
db: Session,
) -> dict[str, Any]:
"""Purge all chunks for a knowledge base using reference mode.

Args:
spec: Purge runtime spec with KB reference.
db: Database session (not used, kept for interface consistency).

Returns:
Purge result.
"""
del db

# Build KB reference
kb_reference = KnowledgeBaseReference(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
)

payload = RemotePurgeKnowledgeIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
knowledge_base_reference=kb_reference,
)
return await self._post_model("/internal/rag/purge-knowledge-index", payload)

Expand All @@ -233,11 +290,26 @@ async def drop_knowledge_index(
*,
db: Session,
) -> dict[str, Any]:
"""Drop the physical index for a knowledge base using reference mode.

Args:
spec: Drop runtime spec with KB reference.
db: Database session (not used, kept for interface consistency).

Returns:
Drop result.
"""
del db

# Build KB reference
kb_reference = KnowledgeBaseReference(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
)

payload = RemoteDropKnowledgeIndexRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
knowledge_base_reference=kb_reference,
)
return await self._post_model("/internal/rag/drop-knowledge-index", payload)

Expand All @@ -247,11 +319,26 @@ async def list_chunks(
*,
db: Session | None = None,
) -> dict[str, Any]:
"""List chunks for a knowledge base using reference mode.

Args:
spec: List chunks runtime spec with KB reference.
db: Database session (not used, kept for interface consistency).

Returns:
List of chunks.
"""
del db

# Build KB reference
kb_reference = KnowledgeBaseReference(
knowledge_base_id=spec.knowledge_base_id,
user_id=spec.index_owner_user_id,
)

payload = RemoteListChunksRequest(
knowledge_base_id=spec.knowledge_base_id,
index_owner_user_id=spec.index_owner_user_id,
retriever_config=spec.retriever_config,
knowledge_base_reference=kb_reference,
max_chunks=spec.max_chunks,
query=spec.query,
metadata_condition=spec.metadata_condition,
Expand All @@ -266,21 +353,25 @@ async def test_connection(
*,
db: Session | None = None,
) -> dict[str, Any]:
del db
payload = RemoteTestConnectionRequest(retriever_config=spec.retriever_config)
return await self._post_model("/internal/rag/test-connection", payload)
"""Test storage backend connection using reference mode.

Args:
spec: Connection test runtime spec with Retriever reference.
db: Database session (not used, kept for interface consistency).

def _get_attachment_source_metadata(
*,
db: Session,
attachment_id: int,
) -> tuple[str | None, str | None]:
context = context_service.get_context_optional(
db=db,
context_id=attachment_id,
)
if context is None or context.context_type != ContextType.ATTACHMENT.value:
return None, None
Returns:
Connection test result.
"""
del db

# Build Retriever reference
retriever_reference = RetrieverReference(
name=spec.retriever_name,
namespace=spec.retriever_namespace,
user_id=spec.user_id or 0,
)

return context.original_filename or None, context.file_extension or None
payload = RemoteTestConnectionRequest(
retriever_reference=retriever_reference,
)
return await self._post_model("/internal/rag/test-connection", payload)
17 changes: 16 additions & 1 deletion backend/app/services/rag/runtime_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,22 @@ class ListChunksRuntimeSpec(RuntimeSpecModel):


class ConnectionTestRuntimeSpec(RuntimeSpecModel):
retriever_config: RuntimeRetrieverConfig
"""Runtime spec for connection testing.

For RemoteGateway (reference mode):
- Use retriever_name, retriever_namespace, user_id to resolve config

For LocalGateway (full config mode):
- Use retriever_config directly
"""

# Reference mode fields (for RemoteGateway)
retriever_name: Optional[str] = None
retriever_namespace: str = "default"
user_id: Optional[int] = None

# Full config mode field (for LocalGateway)
retriever_config: Optional[RuntimeRetrieverConfig] = None


DEFAULT_DIRECT_INJECTION_BUDGET = DirectInjectionBudget()
Loading
Loading