Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/app/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from app.api.endpoints.internal import (
callback_router,
chat_storage_router,
rag_content_router,
services_router,
skills_router,
subscriptions_router,
Expand Down Expand Up @@ -195,6 +196,9 @@
api_router.include_router(
chat_storage_router, prefix="/internal", tags=["internal-chat"]
)
api_router.include_router(
rag_content_router, prefix="/internal", tags=["internal-rag-content"]
)

# RAG internal router is conditionally registered based on STANDALONE_MODE
if not settings.STANDALONE_MODE:
Expand Down
51 changes: 27 additions & 24 deletions backend/app/api/endpoints/adapter/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import TYPE_CHECKING, Optional
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
Expand All @@ -14,6 +14,11 @@
from app.models.user import User
from app.schemas.kind import Retriever
from app.services.adapters.retriever_kinds import retriever_kinds_service
from app.services.rag.gateway_factory import get_query_gateway
from app.services.rag.local_gateway import LocalRagGateway
from app.services.rag.remote_gateway import RemoteRagGatewayError
from app.services.rag.runtime_specs import ConnectionTestRuntimeSpec
from shared.models import RuntimeRetrieverConfig

# RAG storage factory is conditionally imported based on STANDALONE_MODE
# RAG module is heavy (llama_index, scipy, pandas, grpc) - skip in standalone mode
Expand Down Expand Up @@ -213,7 +218,7 @@ def delete_retriever(


@router.post("/test-connection")
def test_retriever_connection(
async def test_retriever_connection(
test_data: dict,
current_user: User = Depends(security.get_current_user),
):
Expand All @@ -236,6 +241,7 @@ def test_retriever_connection(
}
"""
_check_rag_available()
del current_user

storage_type = test_data.get("storage_type")
url = test_data.get("url")
Expand All @@ -250,31 +256,28 @@ def test_retriever_connection(
}

try:
# Create storage backend from config
backend = storage_factory.create_storage_backend_from_config(
storage_type=storage_type,
url=url,
username=username,
password=password,
api_key=api_key,
runtime_spec = ConnectionTestRuntimeSpec(
retriever_config=RuntimeRetrieverConfig(
name="connection-test",
namespace="default",
storage_config={
"type": storage_type,
"url": url,
"username": username,
"password": password,
"apiKey": api_key,
"indexStrategy": {"mode": "per_dataset"},
"ext": {},
},
)
)

# Test connection using backend's test_connection method
success = backend.test_connection()

if success:
return {
"success": True,
"message": f"Successfully connected to {storage_type}",
}
else:
return {
"success": False,
"message": f"Failed to connect to {storage_type}",
}
gateway = get_query_gateway()
try:
return await gateway.test_connection(runtime_spec)
except RemoteRagGatewayError:
return await LocalRagGateway().test_connection(runtime_spec)

except ValueError as e:
# Unsupported storage type
return {"success": False, "message": str(e)}

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions backend/app/api/endpoints/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .bots import router as bots_router
from .callback import router as callback_router
from .chat_storage import router as chat_storage_router
from .rag_content import router as rag_content_router
from .services import router as services_router
from .skills import router as skills_router
from .subscriptions import router as subscriptions_router
Expand All @@ -24,6 +25,7 @@
"bots_router",
"callback_router",
"chat_storage_router",
"rag_content_router",
"services_router",
"skills_router",
"subscriptions_router",
Expand Down
56 changes: 51 additions & 5 deletions backend/app/api/endpoints/internal/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from app.services.knowledge.retrieval_persistence import (
retrieval_persistence_service,
)
from app.services.rag.gateway_factory import get_query_gateway
from app.services.rag.local_gateway import LocalRagGateway
from app.services.rag.remote_gateway import RemoteRagGatewayError
from app.services.rag.retrieval_service import RetrievalService
from app.services.rag.runtime_resolver import RagRuntimeResolver

Expand All @@ -36,7 +38,6 @@

router = APIRouter(prefix="/rag", tags=["internal-rag"])
runtime_resolver = RagRuntimeResolver()
rag_gateway = LocalRagGateway()


class DirectInjectionRuntimeContext(BaseModel):
Expand Down Expand Up @@ -166,6 +167,52 @@ class InternalRetrieveResponse(BaseModel):
total_estimated_tokens: int = 0


def _resolve_query_gateway(runtime_spec):
    """Choose the gateway that should serve ``runtime_spec``.

    The gateway returned by ``get_query_gateway()`` is used only when the
    spec's ``route_mode`` equals ``"rag_retrieval"``. Any other value, or a
    spec that has no ``route_mode`` attribute at all (treated as ``"auto"``),
    gets a freshly constructed ``LocalRagGateway``.
    """
    wants_factory_gateway = (
        getattr(runtime_spec, "route_mode", "auto") == "rag_retrieval"
    )
    if not wants_factory_gateway:
        return LocalRagGateway()
    return get_query_gateway()


def _finalize_query_runtime_spec(runtime_spec, db: Session):
if getattr(runtime_spec, "route_mode", "auto") != "auto":
return runtime_spec
required_attributes = (
"query",
"knowledge_base_ids",
"document_ids",
"direct_injection_budget",
"model_copy",
)
if not all(hasattr(runtime_spec, attr) for attr in required_attributes):
return runtime_spec

retrieval_service = RetrievalService()
budget = getattr(runtime_spec, "direct_injection_budget", None)
resolved_route_mode = retrieval_service.decide_route_mode_for_chat_shell(
query=runtime_spec.query,
knowledge_base_ids=runtime_spec.knowledge_base_ids,
db=db,
route_mode=runtime_spec.route_mode,
document_ids=runtime_spec.document_ids,
context_window=budget.context_window if budget else None,
)
return runtime_spec.model_copy(update={"route_mode": resolved_route_mode})


async def _execute_query_with_remote_fallback(runtime_spec, db: Session):
    """Run the query and retry locally if the first gateway fails remotely.

    The gateway is chosen by ``_resolve_query_gateway``. When its ``query``
    call raises ``RemoteRagGatewayError``, a warning is logged and the same
    spec is re-run through a new ``LocalRagGateway``. Other exceptions
    propagate to the caller.

    Args:
        runtime_spec: Query runtime spec passed to the gateway unchanged.
        db: Database session forwarded to the gateway's ``query`` call.

    Returns:
        Whatever the gateway's ``query`` coroutine returns.
    """
    primary_gateway = _resolve_query_gateway(runtime_spec)
    try:
        outcome = await primary_gateway.query(runtime_spec, db=db)
    except RemoteRagGatewayError as remote_error:
        affected_kbs = getattr(runtime_spec, "knowledge_base_ids", [])
        logger.warning(
            "[internal_rag] Remote query failed for KBs %s, falling back to local gateway: %s",
            affected_kbs,
            remote_error,
        )
        outcome = await LocalRagGateway().query(runtime_spec, db=db)
    return outcome


@router.post(
"/retrieve",
response_model=InternalRetrieveResponse | ProtectedKnowledgeMediationResponse,
Expand Down Expand Up @@ -207,6 +254,7 @@ async def internal_retrieve(
)

runtime_spec = runtime_resolver.build_query_runtime_spec(
db=db,
knowledge_base_ids=knowledge_base_ids,
query=request.query,
max_results=request.max_results,
Expand All @@ -229,10 +277,8 @@ async def internal_retrieve(
),
restricted_mode=restricted_mode,
)
result = await rag_gateway.query(
runtime_spec,
db=db,
)
runtime_spec = _finalize_query_runtime_spec(runtime_spec, db)
result = await _execute_query_with_remote_fallback(runtime_spec, db)

records = result.get("records", [])

Expand Down
90 changes: 90 additions & 0 deletions backend/app/api/endpoints/internal/rag_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: 2026 Weibo, Inc.
#
# SPDX-License-Identifier: Apache-2.0

"""Internal attachment streaming endpoint for knowledge_runtime."""

from __future__ import annotations

import logging
from collections.abc import Iterator

from fastapi import APIRouter, Depends, Header, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.api.dependencies import get_db
from app.api.endpoints.adapter.attachments import _build_content_disposition
from app.models.subtask_context import ContextStatus, ContextType
from app.services.auth import extract_token_from_header
from app.services.auth.rag_download_token import verify_rag_download_token
from app.services.context import context_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/rag/content", tags=["internal-rag-content"])


def _binary_stream(binary_data: bytes) -> Iterator[bytes]:
yield binary_data
Comment on lines +28 to +29
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

This isn't actually streaming yet.

backend/app/services/context/context_service.py returns bytes from get_attachment_binary_data(), so Line 70 has already materialized the whole attachment and Line 85 just yields that buffer once. Large attachments will therefore spike Backend memory on the new remote-indexing path; this needs a chunked/file-like reader from the storage layer instead of wrapping a single bytes object.

Also applies to: 70-85

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@backend/app/api/endpoints/internal/rag_content.py` around lines 28 - 29, The
current _binary_stream yields a single bytes buffer because
get_attachment_binary_data() materializes the entire attachment; change the
streaming contract so get_attachment_binary_data() (in ContextService) returns a
chunked iterator or file-like object and update _binary_stream to iterate/read
from that stream in fixed-size chunks (e.g., 8KB) and yield each chunk, and
update the caller in rag_content.py to pass the new iterator/file-like instead
of a bytes object; ensure any storage backends implement the new
stream-returning method so large attachments are streamed rather than loaded
fully into memory.



def _verify_rag_download_authorization(
    attachment_id: int,
    authorization: str = Header(default=""),
) -> None:
    """Check the Bearer token guarding an internal RAG attachment download.

    Args:
        attachment_id: Attachment requested in the URL path.
        authorization: Raw ``Authorization`` header value (empty if absent).

    Raises:
        HTTPException: 401 when no token can be extracted from the header,
            when the token fails verification, or when the verified token
            names a different attachment.
    """
    bearer_token = extract_token_from_header(authorization)
    if not bearer_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid Authorization header",
        )

    claims = verify_rag_download_token(bearer_token)
    # The token must both verify and be scoped to the requested attachment.
    token_matches = claims is not None and claims.attachment_id == attachment_id
    if token_matches:
        return None
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid RAG download token",
    )


@router.get("/{attachment_id}")
async def stream_rag_attachment_content(
    attachment_id: int,
    _: None = Depends(_verify_rag_download_authorization),
    db: Session = Depends(get_db),
):
    """Return an attachment's binary content for knowledge_runtime.

    Authorization is handled by the ``_verify_rag_download_authorization``
    dependency before this body runs.

    Raises:
        HTTPException: 404 when no context exists for the id or it is not an
            attachment, 409 when the attachment is not in the ready state,
            500 when the binary data cannot be loaded.
    """
    context = context_service.get_context_optional(
        db=db,
        context_id=attachment_id,
    )

    is_attachment = (
        context is not None
        and context.context_type == ContextType.ATTACHMENT.value
    )
    if not is_attachment:
        raise HTTPException(status_code=404, detail="Attachment not found")

    if context.status != ContextStatus.READY.value:
        raise HTTPException(status_code=409, detail="Attachment is not ready")

    payload = context_service.get_attachment_binary_data(
        db=db,
        context=context,
    )
    if payload is None:
        logger.error(
            "Failed to retrieve binary data for internal RAG attachment %s",
            attachment_id,
        )
        raise HTTPException(
            status_code=500,
            detail="Failed to retrieve attachment data",
        )

    body = _binary_stream(payload)
    mime_type = context.mime_type
    response_headers = {
        "Content-Disposition": _build_content_disposition(context.original_filename)
    }
    return StreamingResponse(
        body,
        media_type=mime_type,
        headers=response_headers,
    )
Loading
Loading