Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/app/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from app.api.endpoints.internal import (
callback_router,
chat_storage_router,
rag_content_router,
services_router,
skills_router,
subscriptions_router,
Expand Down Expand Up @@ -195,6 +196,9 @@
api_router.include_router(
chat_storage_router, prefix="/internal", tags=["internal-chat"]
)
api_router.include_router(
rag_content_router, prefix="/internal", tags=["internal-rag-content"]
)

# RAG internal router is conditionally registered based on STANDALONE_MODE
if not settings.STANDALONE_MODE:
Expand Down
51 changes: 27 additions & 24 deletions backend/app/api/endpoints/adapter/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import TYPE_CHECKING, Optional
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
Expand All @@ -14,6 +14,11 @@
from app.models.user import User
from app.schemas.kind import Retriever
from app.services.adapters.retriever_kinds import retriever_kinds_service
from app.services.rag.gateway_factory import get_query_gateway
from app.services.rag.local_gateway import LocalRagGateway
from app.services.rag.remote_gateway import RemoteRagGatewayError
from app.services.rag.runtime_specs import ConnectionTestRuntimeSpec
from shared.models import RuntimeRetrieverConfig

# RAG storage factory is conditionally imported based on STANDALONE_MODE
# RAG module is heavy (llama_index, scipy, pandas, grpc) - skip in standalone mode
Expand Down Expand Up @@ -213,7 +218,7 @@ def delete_retriever(


@router.post("/test-connection")
def test_retriever_connection(
async def test_retriever_connection(
test_data: dict,
current_user: User = Depends(security.get_current_user),
):
Expand All @@ -236,6 +241,7 @@ def test_retriever_connection(
}
"""
_check_rag_available()
del current_user

storage_type = test_data.get("storage_type")
url = test_data.get("url")
Expand All @@ -250,31 +256,28 @@ def test_retriever_connection(
}

try:
# Create storage backend from config
backend = storage_factory.create_storage_backend_from_config(
storage_type=storage_type,
url=url,
username=username,
password=password,
api_key=api_key,
runtime_spec = ConnectionTestRuntimeSpec(
retriever_config=RuntimeRetrieverConfig(
name="connection-test",
namespace="default",
storage_config={
"type": storage_type,
"url": url,
"username": username,
"password": password,
"apiKey": api_key,
"indexStrategy": {"mode": "per_dataset"},
"ext": {},
},
)
)

# Test connection using backend's test_connection method
success = backend.test_connection()

if success:
return {
"success": True,
"message": f"Successfully connected to {storage_type}",
}
else:
return {
"success": False,
"message": f"Failed to connect to {storage_type}",
}
gateway = get_query_gateway()
try:
return await gateway.test_connection(runtime_spec)
except RemoteRagGatewayError:
return await LocalRagGateway().test_connection(runtime_spec)

except ValueError as e:
# Unsupported storage type
return {"success": False, "message": str(e)}

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions backend/app/api/endpoints/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .bots import router as bots_router
from .callback import router as callback_router
from .chat_storage import router as chat_storage_router
from .rag_content import router as rag_content_router
from .services import router as services_router
from .skills import router as skills_router
from .subscriptions import router as subscriptions_router
Expand All @@ -24,6 +25,7 @@
"bots_router",
"callback_router",
"chat_storage_router",
"rag_content_router",
"services_router",
"skills_router",
"subscriptions_router",
Expand Down
56 changes: 51 additions & 5 deletions backend/app/api/endpoints/internal/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from app.services.knowledge.retrieval_persistence import (
retrieval_persistence_service,
)
from app.services.rag.gateway_factory import get_query_gateway
from app.services.rag.local_gateway import LocalRagGateway
from app.services.rag.remote_gateway import RemoteRagGatewayError
from app.services.rag.retrieval_service import RetrievalService
from app.services.rag.runtime_resolver import RagRuntimeResolver

Expand All @@ -36,7 +38,6 @@

router = APIRouter(prefix="/rag", tags=["internal-rag"])
runtime_resolver = RagRuntimeResolver()
rag_gateway = LocalRagGateway()


class DirectInjectionRuntimeContext(BaseModel):
Expand Down Expand Up @@ -166,6 +167,52 @@ class InternalRetrieveResponse(BaseModel):
total_estimated_tokens: int = 0


def _resolve_query_gateway(runtime_spec):
route_mode = getattr(runtime_spec, "route_mode", "auto")
if route_mode == "rag_retrieval":
return get_query_gateway()
return LocalRagGateway()


def _finalize_query_runtime_spec(runtime_spec, db: Session):
if getattr(runtime_spec, "route_mode", "auto") != "auto":
return runtime_spec
required_attributes = (
"query",
"knowledge_base_ids",
"document_ids",
"direct_injection_budget",
"model_copy",
)
if not all(hasattr(runtime_spec, attr) for attr in required_attributes):
return runtime_spec

retrieval_service = RetrievalService()
budget = getattr(runtime_spec, "direct_injection_budget", None)
resolved_route_mode = retrieval_service.decide_route_mode_for_chat_shell(
query=runtime_spec.query,
knowledge_base_ids=runtime_spec.knowledge_base_ids,
db=db,
route_mode=runtime_spec.route_mode,
document_ids=runtime_spec.document_ids,
context_window=budget.context_window if budget else None,
)
return runtime_spec.model_copy(update={"route_mode": resolved_route_mode})


async def _execute_query_with_remote_fallback(runtime_spec, db: Session):
rag_gateway = _resolve_query_gateway(runtime_spec)
try:
return await rag_gateway.query(runtime_spec, db=db)
except RemoteRagGatewayError as exc:
logger.warning(
"[internal_rag] Remote query failed for KBs %s, falling back to local gateway: %s",
getattr(runtime_spec, "knowledge_base_ids", []),
exc,
)
return await LocalRagGateway().query(runtime_spec, db=db)


@router.post(
"/retrieve",
response_model=InternalRetrieveResponse | ProtectedKnowledgeMediationResponse,
Expand Down Expand Up @@ -207,6 +254,7 @@ async def internal_retrieve(
)

runtime_spec = runtime_resolver.build_query_runtime_spec(
db=db,
knowledge_base_ids=knowledge_base_ids,
query=request.query,
max_results=request.max_results,
Expand All @@ -229,10 +277,8 @@ async def internal_retrieve(
),
restricted_mode=restricted_mode,
)
result = await rag_gateway.query(
runtime_spec,
db=db,
)
runtime_spec = _finalize_query_runtime_spec(runtime_spec, db)
result = await _execute_query_with_remote_fallback(runtime_spec, db)

records = result.get("records", [])

Expand Down
85 changes: 85 additions & 0 deletions backend/app/api/endpoints/internal/rag_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: 2026 Weibo, Inc.
#
# SPDX-License-Identifier: Apache-2.0

"""Internal attachment streaming endpoint for knowledge_runtime."""

from __future__ import annotations

import logging

from fastapi import APIRouter, Depends, Header, HTTPException, status
from fastapi.responses import Response
from sqlalchemy.orm import Session

from app.api.dependencies import get_db
from app.api.endpoints.adapter.attachments import _build_content_disposition
from app.models.subtask_context import ContextStatus, ContextType
from app.services.auth import extract_token_from_header
from app.services.auth.rag_download_token import verify_rag_download_token
from app.services.context import context_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/rag/content", tags=["internal-rag-content"])


def _verify_rag_download_authorization(
attachment_id: int,
authorization: str = Header(default=""),
) -> None:
"""Validate Bearer token for internal RAG attachment download."""

token = extract_token_from_header(authorization)
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid Authorization header",
)

token_info = verify_rag_download_token(token)
if token_info is None or token_info.attachment_id != attachment_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid RAG download token",
)


@router.get("/{attachment_id}")
async def stream_rag_attachment_content(
attachment_id: int,
_: None = Depends(_verify_rag_download_authorization),
db: Session = Depends(get_db),
):
"""Stream attachment binary content for knowledge_runtime."""

context = context_service.get_context_optional(
db=db,
context_id=attachment_id,
)
if context is None or context.context_type != ContextType.ATTACHMENT.value:
raise HTTPException(status_code=404, detail="Attachment not found")
if context.status != ContextStatus.READY.value:
raise HTTPException(status_code=409, detail="Attachment is not ready")

binary_data = context_service.get_attachment_binary_data(
db=db,
context=context,
)
if binary_data is None:
logger.error(
"Failed to retrieve binary data for internal RAG attachment %s",
attachment_id,
)
raise HTTPException(
status_code=500,
detail="Failed to retrieve attachment data",
)

return Response(
content=binary_data,
media_type=context.mime_type,
headers={
"Content-Disposition": _build_content_disposition(context.original_filename)
},
)
Loading
Loading