diff --git a/0 b/0 new file mode 100644 index 0000000..db99733 Binary files /dev/null and b/0 differ diff --git a/backend/app/__pycache__/__init__.cpython-313.pyc b/backend/app/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 8f98708..0000000 Binary files a/backend/app/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/api/__pycache__/__init__.cpython-313.pyc b/backend/app/api/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index ed80ef2..0000000 Binary files a/backend/app/api/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/api/attachment_api.py b/backend/app/api/attachment_api.py new file mode 100644 index 0000000..ab70759 --- /dev/null +++ b/backend/app/api/attachment_api.py @@ -0,0 +1,183 @@ +import base64 +import logging +from fastapi import APIRouter, UploadFile, File +from fastapi.responses import JSONResponse + +from app.core.config import settings + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/upload", tags=["upload"]) + + +def _extract_pdf_text(file_bytes: bytes) -> tuple[str | None, str]: + """尝试提取 PDF 文本内容 + + Returns: + tuple: (提取的文本或None, 错误信息) + """ + try: + import fitz + import io + except ImportError: + return None, "PyMuPDF库未安装,无法解析PDF文件" + + try: + doc = fitz.open(stream=io.BytesIO(file_bytes), filetype="pdf") + if doc.is_closed: + return None, "PDF文件格式错误,无法打开" + text = "" + for page_num, page in enumerate(doc): + try: + page_text = page.get_text() + text += page_text + except Exception as e: + logger.warning(f"[PDF] 第{page_num + 1}页提取失败: {e}") + continue + doc.close() + + if not text.strip(): + return None, "PDF文件中没有可提取的文本内容(可能是扫描件或图片型PDF)" + return text.strip(), "" + except Exception as e: + error_msg = f"PDF文件解析失败: {str(e)}" + logger.warning(f"[UploadForward] {error_msg}") + return None, error_msg + + +def _extract_docx_text(file_bytes: bytes) -> tuple[str | None, str]: + """尝试提取 Word 文档文本内容 + + Returns: + tuple: 
(提取的文本或None, 错误信息) + """ + try: + import docx + import io + except ImportError: + return None, "python-docx库未安装,无法解析Word文档" + + try: + doc = docx.Document(io.BytesIO(file_bytes)) + paragraphs = [] + for para in doc.paragraphs: + if para.text.strip(): + paragraphs.append(para.text.strip()) + + for table in doc.tables: + for row in table.rows: + row_text = " | ".join(cell.text.strip() for cell in row.cells if cell.text.strip()) + if row_text: + paragraphs.append(row_text) + + text = "\n".join(paragraphs) + if not text: + return None, "Word文档中没有可提取的文本内容" + return text, "" + except Exception as e: + error_msg = f"Word文档解析失败: {str(e)}" + logger.warning(f"[UploadForward] {error_msg}") + return None, error_msg + + +@router.post("/forward") +async def forward_file(file: UploadFile = File(...)): + try: + file_bytes = await file.read() + except Exception as e: + logger.warning(f"[UploadForward] 文件读取失败: {e}") + return JSONResponse( + status_code=400, + content={"status": "error", "message": "文件读取失败,请重新上传"}, + ) + + if len(file_bytes) > settings.FILE_MAX_SIZE: + return JSONResponse( + status_code=400, + content={"status": "error", "message": f"文件大小超过限制 ({settings.FILE_MAX_SIZE // 1024 // 1024}MB)"}, + ) + + try: + filename = file.filename or "unknown" + content_type = file.content_type or "" + ext = filename.split('.')[-1].lower() if '.' 
in filename else '' + + if not ext and content_type: + ext_map = { + "image/jpeg": "jpg", + "image/png": "png", + "image/gif": "gif", + "image/bmp": "bmp", + "image/webp": "webp", + "application/pdf": "pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", + } + ext = ext_map.get(content_type, "") + + image_extensions = {'jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp'} + if ext in image_extensions: + b64 = base64.b64encode(file_bytes).decode('utf-8') + mime = f"image/{ext if ext != 'jpg' else 'jpeg'}" + content = f"data:{mime};base64,{b64}" + return { + "status": "success", + "content": content, + "type": "image", + "filename": filename, + } + + if ext == 'pdf': + text, error = _extract_pdf_text(file_bytes) + if text: + return { + "status": "success", + "content": text, + "type": "text", + "filename": filename, + } + return JSONResponse( + status_code=400, + content={"status": "error", "message": error or "无法提取PDF文本内容"}, + ) + + if ext in ('docx', 'doc'): + text, error = _extract_docx_text(file_bytes) + if text: + return { + "status": "success", + "content": text, + "type": "text", + "filename": filename, + } + return JSONResponse( + status_code=400, + content={"status": "error", "message": error or "无法提取Word文档文本内容"}, + ) + + text_extensions = {'txt', 'md', 'csv', 'json', 'xml', 'html', 'css', 'js', 'py', 'java', 'cpp', 'c', 'h', 'go', 'rs', 'ts', 'sql', 'yaml', 'yml'} + if ext in text_extensions: + try: + text = file_bytes.decode('utf-8') + except UnicodeDecodeError: + try: + text = file_bytes.decode('gbk') + except UnicodeDecodeError: + text = file_bytes.decode('utf-8', errors='ignore') + return { + "status": "success", + "content": text, + "type": "text", + "filename": filename, + } + + return JSONResponse( + status_code=400, + content={"status": "error", "message": f"不支持的文件类型: .{ext}"}, + ) + + except Exception as e: + logger.warning(f"[UploadForward] 文件处理失败: {e}") + return JSONResponse( + status_code=500, + content={"status": 
"error", "message": "文件上传处理失败,请稍后重试"}, + ) diff --git a/backend/app/api/v1/__pycache__/__init__.cpython-313.pyc b/backend/app/api/v1/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 84ffc7a..0000000 Binary files a/backend/app/api/v1/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/api/v1/endpoints/avatar.py b/backend/app/api/v1/endpoints/avatar.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/api/v1/endpoints/chat.py b/backend/app/api/v1/endpoints/chat.py index 2d04d78..228d707 100644 --- a/backend/app/api/v1/endpoints/chat.py +++ b/backend/app/api/v1/endpoints/chat.py @@ -3,6 +3,7 @@ import time import asyncio from datetime import datetime, timezone +from collections.abc import AsyncIterator from fastapi import APIRouter from fastapi.responses import StreamingResponse from loguru import logger @@ -18,6 +19,11 @@ from app.runtime.provider.llm.adapter import llm_adapter from app.infrastructure.database.json_store import conversations_store, agents_store from app.core.config import settings +from app.utils.intent_gateway import classify_request, RequestType +from app.utils.tool_lazy_loader import get_matched_tools +from app.utils.tool_result_processor import process_tool_result +from app.utils.local_handler import handle_local_tool_request +from app.utils.tool_executor import execute_tool_chain, build_tool_summary, execute_single_tool router = APIRouter(prefix="/chat", tags=["chat"]) @@ -39,6 +45,94 @@ def _get_user_query(messages: list[dict]) -> str: return "" +def _resolve_tools(user_message: str, request_type: RequestType) -> list[dict] | None: + """按需解析工具定义 —— 仅 TOOL_CALL 类型才注入匹配场景的工具 + + GENERAL_CHAT 和 LOCAL_TOOL 请求绝不注入任何工具,从根源杜绝工具乱触发。 + + 异常安全: + 懒加载异常时返回空列表 [](不注入任何工具),避免全量注入导致工具乱触发。 + GENERAL_CHAT 和 LOCAL_TOOL 请求始终返回 None。 + + 参数: + user_message: 用户原始消息文本 + request_type: classify_request 返回的请求类型 + + 返回: + - TOOL_CALL 且命中场景:OpenAI Function Calling 格式工具列表 + - TOOL_CALL 但无匹配场景:空列表 
[](等效不注入工具) + - 其他类型:None(不注入工具) + - 异常:空列表 [](安全降级,不注入工具) + """ + if request_type != RequestType.TOOL_CALL: + return None + + try: + tools = get_matched_tools(user_message) + return tools if tools else [] + except Exception as e: + logger.warning(f"[Chat] 工具懒加载异常,降级返回空列表(不注入工具): {e}") + return [] + + +def _inject_system_prompt(messages: list[dict]) -> list[dict]: + from datetime import datetime + now = datetime.now() + weekday_names = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'] + current_date = now.strftime("%Y年%m月%d日") + current_weekday = weekday_names[now.weekday()] + current_time = now.strftime("%H:%M") + date_prompt = f"当前时间:{current_date} {current_weekday} {current_time} (Asia/Shanghai)。请基于这个时间回答用户的问题。当用户问「距离XX还有几天」时,你需要先用工具查询目标日期,然后用当前时间计算差值。" + + has_system = False + for msg in messages: + if msg.get("role") == "system": + has_system = True + existing = msg.get("content", "") + if "当前时间" not in existing: + msg["content"] = date_prompt + "\n\n" + existing + break + + if not has_system: + messages = [{"role": "system", "content": date_prompt}] + messages + + return messages + + +def _inject_file_content(messages: list[dict], parsed_content: str, file_type: str = "text") -> list[dict]: + if not parsed_content or not parsed_content.strip(): + return messages + + # 根据文件类型判断是否是图片 + is_image = file_type == "image" or parsed_content.startswith("data:image") + + if is_image: + # 找到最后一条用户消息,将图片内容附加到该消息 + for i in range(len(messages) - 1, -1, -1): + if messages[i]["role"] == "user": + # 提取文字内容和图片 + text_content = messages[i]["content"] + # 移除 [图片附件] 标记后的内容 + if "[图片附件]" in text_content: + text_content = text_content.split("[图片附件]")[0].strip() + + # 构建多模态消息格式 + messages[i]["content"] = [ + {"type": "text", "text": text_content or "请分析这张图片"}, + {"type": "image_url", "image_url": {"url": parsed_content}}, + ] + return messages + return messages + + # 普通文本内容 + context_text = ( + "[用户上传文件内容] 以下是与当前对话相关的文件内容,请参考这些内容回答用户的问题。" + "如果用户的问题与文件内容无关,请正常回答用户问题,不需要强行关联文件。\n\n" 
+ + parsed_content + ) + return [{"role": "user", "content": context_text}] + messages + + async def _inject_memory(messages: list[dict], agent_id: str | None = None, provider_name: str | None = None) -> list[dict]: try: from app.engines.memory.core import MemoryInjector, get_memory_storage @@ -98,6 +192,35 @@ async def _do_update(): logger.warning(f"[Memory] Update scheduling skipped: {e}") +def _persist_conv(conv_id: str, conv: dict) -> None: + conv["updated_at"] = datetime.now(timezone.utc).isoformat() + conversations_store.set(conv_id, conv) + + +def _append_user_msg(conv: dict, content: str, file_content: str | None = None) -> dict: + entry: dict = {"role": "user", "content": content} + if file_content: + entry["file_content"] = file_content + last = conv["messages"][-1] if conv["messages"] else None + if not last or last != entry: + conv["messages"].append(entry) + return entry + return last + + +def _append_assistant_msg(conv: dict, content: str, reasoning: str | None = None, interrupted: bool = False) -> dict: + entry: dict = {"role": "assistant", "content": content} + if reasoning: + entry["reasoning_content"] = reasoning + if interrupted: + entry["interrupted"] = True + last = conv["messages"][-1] if conv["messages"] else None + if not last or last.get("content") != content or (reasoning and last.get("reasoning_content") != reasoning): + conv["messages"].append(entry) + return entry + return last + + @router.post("/completions") async def chat_completions(request: ChatRequest): start_time = time.time() @@ -106,12 +229,172 @@ async def chat_completions(request: ChatRequest): logger.info(f"[API] POST /chat/completions - provider={resolved_provider}, model={resolved_model}, stream={request.stream}") messages = [{"role": m.role, "content": m.content} for m in request.messages] + messages = _inject_system_prompt(messages) messages = await _inject_memory(messages, request.agent_id, resolved_provider) + # 意图分类 + 按需工具加载(仅 TOOL_CALL 类型注入匹配场景的工具) + user_query = 
_get_user_query(messages) + request_type = classify_request(user_query) + tools = _resolve_tools(user_query, request_type) + tools_count = len(tools) if tools else 0 + logger.info(f"[API] 意图分类: type={request_type.value}, tools_injected={tools_count}") + + # LOCAL_TOOL:本地工具直接处理,不走 LLM + if request_type == RequestType.LOCAL_TOOL: + result = await handle_local_tool_request(user_query) + if result is None: + raw = await llm_adapter.chat( + messages=messages, + provider_name=resolved_provider, + model=resolved_model, + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + ) + result_content = raw.get("content") if isinstance(raw, dict) else raw + result_reasoning = raw.get("reasoning") if isinstance(raw, dict) else None + else: + result_content = result.get("content") if isinstance(result, dict) else result + result_reasoning = result.get("reasoning") if isinstance(result, dict) else None + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [LOCAL_TOOL] - Success: elapsed={elapsed:.2f}s") + + if request.stream: + chat_id = str(uuid.uuid4()) + data = ChatStreamChunk(id=chat_id, content=result_content, reasoning_content=result_reasoning or "", model=resolved_model, provider=resolved_provider) + done_data = ChatStreamChunk(id=chat_id, content="", model=resolved_model, provider=resolved_provider, done=True) + + async def _local_tool_stream(): + yield f"data: {data.model_dump_json()}\n\n" + yield f"data: {done_data.model_dump_json()}\n\n" + + return StreamingResponse(_local_tool_stream(), media_type="text/event-stream") + + return ChatResponse( + id=str(uuid.uuid4()), + content=result_content, + model=resolved_model, + provider=resolved_provider, + ) + + # TOOL_CALL:先走本地工具快速路径,未匹配则 Tool Loop / 规则驱动 + if request_type == RequestType.TOOL_CALL: + local_result = await handle_local_tool_request(user_query) + if local_result is not None: + local_result_content = local_result.get("content") if 
isinstance(local_result, dict) else local_result + local_result_reasoning = local_result.get("reasoning") if isinstance(local_result, dict) else None + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [TOOL local] - Success: elapsed={elapsed:.2f}s") + + if request.stream: + chat_id = str(uuid.uuid4()) + data = ChatStreamChunk(id=chat_id, content=local_result_content, reasoning_content=local_result_reasoning or "", model=resolved_model, provider=resolved_provider) + done_data = ChatStreamChunk(id=chat_id, content="", model=resolved_model, provider=resolved_provider, done=True) + + async def _tool_local_stream(): + yield f"data: {data.model_dump_json()}\n\n" + yield f"data: {done_data.model_dump_json()}\n\n" + + return StreamingResponse(_tool_local_stream(), media_type="text/event-stream") + + return ChatResponse( + id=str(uuid.uuid4()), + content=local_result_content, + model=resolved_model, + provider=resolved_provider, + ) + + fc_supported = llm_adapter.supports_tool_calls(resolved_provider, resolved_model) + if fc_supported: + from app.core.agent.tool_loop import tool_loop, tool_loop_stream, get_all_tools_schema + fc_tools = get_all_tools_schema() + if fc_tools: + loop_kwargs = {} + if request.temperature is not None: loop_kwargs["temperature"] = request.temperature or 0.7 + if request.max_tokens is not None: loop_kwargs["max_tokens"] = request.max_tokens or 4096 + if request.top_p is not None: loop_kwargs["top_p"] = request.top_p or 0.9 + + if request.stream: + async def _tool_loop_stream(): + chat_id = str(uuid.uuid4()) + async for event in tool_loop_stream( + messages=messages, tools=fc_tools, + provider_name=resolved_provider, model=resolved_model, **loop_kwargs, + ): + etype = event.get("type") + if etype == "content": + yield f"data: {ChatStreamChunk(id=chat_id, content=event.get('content', ''), model=resolved_model, provider=resolved_provider).model_dump_json()}\n\n" + elif etype == "reasoning": + yield f"data: 
{ChatStreamChunk(id=chat_id, content='', reasoning_content=event.get('content', ''), model=resolved_model, provider=resolved_provider).model_dump_json()}\n\n" + elif etype == "done": + c = event.get("content", "") + if c: + yield f"data: {ChatStreamChunk(id=chat_id, content=c, model=resolved_model, provider=resolved_provider).model_dump_json()}\n\n" + yield f"data: {ChatStreamChunk(id=chat_id, content='', model=resolved_model, provider=resolved_provider, done=True).model_dump_json()}\n\n" + + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [TOOL loop stream] - elapsed={elapsed:.2f}s") + return StreamingResponse(_tool_loop_stream(), media_type="text/event-stream") + + result = await tool_loop( + messages=messages, tools=fc_tools, + provider_name=resolved_provider, model=resolved_model, **loop_kwargs, + ) + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [TOOL loop] - elapsed={elapsed:.2f}s") + return ChatResponse( + id=str(uuid.uuid4()), + content=result.get("content", ""), + model=resolved_model, + provider=resolved_provider, + ) + + tool_results = await execute_tool_chain( + user_query, + agent_id=request.agent_id, + external_search_results=request.search_results, + ) + if tool_results: + summary_prompt = build_tool_summary(user_query, tool_results) + summary_messages = [{"role": "user", "content": summary_prompt}] + + if request.stream: + async def _tool_chain_stream(): + chat_id = str(uuid.uuid4()) + yield f"data: {ChatStreamChunk(id=chat_id, content='', reasoning_content='正在查询所需信息…', model=resolved_model, provider=resolved_provider).model_dump_json()}\n\n" + async for chunk in llm_adapter.chat_stream( + messages=summary_messages, provider_name=resolved_provider, model=resolved_model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + yield f"data: {ChatStreamChunk(id=chat_id, content=chunk.get('content', ''), 
model=resolved_model, provider=resolved_provider).model_dump_json()}\n\n" + yield f"data: {ChatStreamChunk(id=chat_id, content='', model=resolved_model, provider=resolved_provider, done=True).model_dump_json()}\n\n" + + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [TOOL chain stream] - Success: elapsed={elapsed:.2f}s") + return StreamingResponse(_tool_chain_stream(), media_type="text/event-stream") + + raw = await llm_adapter.chat( + messages=summary_messages, provider_name=resolved_provider, model=resolved_model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ) + result = raw.get("content") if isinstance(raw, dict) else raw + elapsed = time.time() - start_time + logger.success(f"[API] POST /chat/completions [TOOL chain] - Success: elapsed={elapsed:.2f}s") + return ChatResponse( + id=str(uuid.uuid4()), + content=result, + model=resolved_model, + provider=resolved_provider, + ) + + # 无匹配工具 → 降级到通用对话 + if request.stream: logger.info(f"[API] POST /chat/completions - Starting stream response") return StreamingResponse( - _stream_chat(messages, request, resolved_provider, resolved_model), + _stream_chat(messages, request, resolved_provider, resolved_model, tools), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", @@ -130,11 +413,12 @@ async def chat_completions(request: ChatRequest): top_p=request.top_p, ) + result_content = result.get("content") if isinstance(result, dict) else result elapsed = time.time() - start_time - logger.success(f"[API] POST /chat/completions - Success: elapsed={elapsed:.2f}s, response_len={len(result)}") + logger.success(f"[API] POST /chat/completions - Success: elapsed={elapsed:.2f}s, response_len={len(result_content)}") return ChatResponse( id=str(uuid.uuid4()), - content=result, + content=result_content, model=resolved_model, provider=resolved_provider, ) @@ -144,53 +428,26 @@ async def chat_completions(request: 
ChatRequest): raise -async def _stream_chat(messages: list[dict], request: ChatRequest, provider: str, model: str): - start_time = time.time() +async def _stream_chat(messages: list[dict], request: ChatRequest, provider: str, model: str, tools: list[dict] | None = None): chat_id = str(uuid.uuid4()) - chunk_count = 0 - - logger.info(f"[STREAM] Starting stream: chat_id={chat_id}, provider={provider}, model={model}") - - try: - async for chunk in llm_adapter.chat_stream( - messages=messages, - provider_name=provider, - model=model, - temperature=request.temperature, - max_tokens=request.max_tokens, - top_p=request.top_p, - ): - chunk_count += 1 - data = ChatStreamChunk( - id=chat_id, - content=chunk, - model=model, - provider=provider, - ) - yield f"data: {data.model_dump_json()}\n\n" - - done_data = ChatStreamChunk( - id=chat_id, - content="", - model=model, - provider=provider, - done=True, - ) - yield f"data: {done_data.model_dump_json()}\n\n" + full_reply = "" + async for chunk in llm_adapter.chat_stream( + messages=messages, + provider_name=provider, + model=model, + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + ): + content = chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + full_reply += content + data = ChatStreamChunk(id=chat_id, content=content, reasoning_content=rc, model=model, provider=provider) + yield f"data: {data.model_dump_json()}\n\n" - elapsed = time.time() - start_time - logger.success(f"[STREAM] Stream completed: chat_id={chat_id}, chunks={chunk_count}, elapsed={elapsed:.2f}s") - except Exception as e: - elapsed = time.time() - start_time - logger.error(f"[STREAM] Stream failed: chat_id={chat_id}, elapsed={elapsed:.2f}s, error={e}") - error_data = ChatStreamChunk( - id=chat_id, - content=f"[Error] {str(e)}", - model=model, - provider=provider, - done=True, - ) - yield f"data: {error_data.model_dump_json()}\n\n" + done_data = ChatStreamChunk(id=chat_id, content="", model=model, 
provider=provider, done=True) + yield f"data: {done_data.model_dump_json()}\n\n" @router.get("/conversations", response_model=list[ConversationListResponse]) @@ -266,132 +523,404 @@ async def delete_conversation(conv_id: str): async def add_message(conv_id: str, request: ChatRequest): start_time = time.time() logger.info(f"[API] POST /chat/conversations/{conv_id}/messages - Adding message") + conv = conversations_store.get(conv_id) if not conv: - logger.error(f"[API] POST /chat/conversations/{conv_id}/messages - Conversation not found") from app.core.exceptions import NotFoundError raise NotFoundError(f"Conversation {conv_id} not found") - last_user_msg = None + last_user_content = "" for m in reversed(request.messages): if m.role == "user": - last_user_msg = m + last_user_content = m.content break - if last_user_msg: - msg_entry = {"role": "user", "content": last_user_msg.content} - if not conv["messages"] or conv["messages"][-1] != msg_entry: - conv["messages"].append(msg_entry) - logger.debug(f"[API] POST /chat/conversations/{conv_id}/messages - Added user message") - - conv["updated_at"] = datetime.now(timezone.utc).isoformat() + _phase_1_save_user_msg(conv, last_user_content, request.file_content, request.file_name, request.file_type) + _persist_conv(conv_id, conv) resolved_provider = request.provider or conv.get("provider") or llm_adapter.default_provider resolved_model = request.model or conv.get("model") or llm_adapter.get_provider(resolved_provider).default_model all_messages = [] for m in conv["messages"]: - all_messages.append({"role": m["role"], "content": m["content"]}) + msg = {"role": m["role"], "content": m["content"]} + # 如果 content 是列表(多模态格式),保留原样 + if isinstance(m.get("content"), list): + msg["content"] = m["content"] + all_messages.append(msg) + all_messages = _inject_system_prompt(all_messages) agent_id = request.agent_id or conv.get("agent_id") all_messages = await _inject_memory(all_messages, agent_id, resolved_provider) - if request.stream: 
- logger.info(f"[API] POST /chat/conversations/{conv_id}/messages - Starting stream response") - - async def stream_with_save(): - final_answer = "" - chat_id = str(uuid.uuid4()) - chunk_count = 0 - try: - async for chunk in llm_adapter.chat_stream( - messages=all_messages, - provider_name=resolved_provider, - model=resolved_model, - temperature=request.temperature, - max_tokens=request.max_tokens, - top_p=request.top_p, - ): - final_answer += chunk - chunk_count += 1 - data = ChatStreamChunk( - id=chat_id, - content=chunk, - model=resolved_model, - provider=resolved_provider, - ) - yield f"data: {data.model_dump_json()}\n\n" - - done_data = ChatStreamChunk( - id=chat_id, - content="", - model=resolved_model, - provider=resolved_provider, - done=True, - ) - yield f"data: {done_data.model_dump_json()}\n\n" - - assistant_msg = {"role": "assistant", "content": final_answer} - if not conv["messages"] or conv["messages"][-1] != assistant_msg: - conv["messages"].append(assistant_msg) - conv["updated_at"] = datetime.now(timezone.utc).isoformat() - conversations_store.set(conv_id, conv) - - _schedule_memory_update( - conv["messages"], conv_id, agent_id, - provider_name=resolved_provider, - ) - - elapsed = time.time() - start_time - logger.success(f"[STREAM] Stream completed & saved: conv={conv_id}, chunks={chunk_count}, elapsed={elapsed:.2f}s") - except Exception as e: - elapsed = time.time() - start_time - logger.error(f"[STREAM] Stream failed: conv={conv_id}, elapsed={elapsed:.2f}s, error={e}") - error_data = ChatStreamChunk( - id=chat_id, - content=f"[Error] {str(e)}", - model=resolved_model, - provider=resolved_provider, - done=True, - ) - yield f"data: {error_data.model_dump_json()}\n\n" - - return StreamingResponse( - stream_with_save(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) + if request.file_content: + logger.info(f"[API] 文件内容注入: file_type={request.file_type}, 
content_length={len(request.file_content)}, is_image={request.file_type == 'image'}") + all_messages = _inject_file_content(all_messages, request.file_content, request.file_type or "text") - result = await llm_adapter.chat( - messages=all_messages, - provider_name=resolved_provider, - model=resolved_model, - temperature=request.temperature, - max_tokens=request.max_tokens, - top_p=request.top_p, - ) + user_query = _get_user_query(all_messages) + request_type = classify_request(user_query) + tools = _resolve_tools(user_query, request_type) + logger.info(f"[API] Intent={request_type.value}, tools={len(tools) if tools else 0}") - assistant_msg = {"role": "assistant", "content": result} - if not conv["messages"] or conv["messages"][-1] != assistant_msg: - conv["messages"].append(assistant_msg) - conv["updated_at"] = datetime.now(timezone.utc).isoformat() - conversations_store.set(conv_id, conv) + gen_state: dict = { + "content": "", + "reasoning": "", + "aborted": False, + "started": True, + } - _schedule_memory_update( - conv["messages"], conv_id, agent_id, - provider_name=resolved_provider, - ) + if request.stream: + return await _STREAM_RESPONSE(conv_id, conv, request, all_messages, user_query, + request_type, tools, resolved_provider, resolved_model, + agent_id, gen_state, start_time, + search_results=request.search_results) + + await _NON_STREAM_GENERATE(gen_state, request_type, user_query, all_messages, + resolved_provider, resolved_model, tools, agent_id, + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + search_results=request.search_results) + + _PHASE_3_SAVE_ASSISTANT_MSG(conv, gen_state) + _persist_conv(conv_id, conv) + _schedule_memory_update(conv["messages"], conv_id, agent_id, provider_name=resolved_provider) elapsed = time.time() - start_time - logger.success(f"[API] POST /chat/conversations/{conv_id}/messages - Success: elapsed={elapsed:.2f}s, response_len={len(result)}") + logger.success(f"[API] Done: 
conv={conv_id}, elapsed={elapsed:.2f}s, len={len(gen_state['content'])}, aborted={gen_state['aborted']}") return ChatResponse( id=str(uuid.uuid4()), - content=result, + content=gen_state["content"], model=resolved_model, provider=resolved_provider, ) + + +def _phase_1_save_user_msg(conv: dict, content: str, file_content: str | None = None, file_name: str | None = None, file_type: str | None = None) -> None: + if not content: + return + entry: dict = {"role": "user", "content": content} + if file_content: + entry["file_content"] = file_content + if file_name: + entry["file_name"] = file_name + if file_type: + entry["file_type"] = file_type + if file_content and file_name: + entry["files"] = [{"name": file_name, "type": file_type, "content": file_content}] + last = conv["messages"][-1] if conv["messages"] else None + if not last or last != entry: + conv["messages"].append(entry) + + +def _PHASE_3_SAVE_ASSISTANT_MSG(conv: dict, state: dict) -> None: + content = state["content"] or "[已中断]" + reasoning = state["reasoning"] or None + interrupted = state["aborted"] + entry: dict = {"role": "assistant", "content": content} + if reasoning: + entry["reasoning_content"] = reasoning + if interrupted: + entry["interrupted"] = True + last = conv["messages"][-1] if conv["messages"] else None + if not last or last.get("content") != content: + conv["messages"].append(entry) + + +async def _NON_STREAM_GENERATE(state: dict, request_type: RequestType, + user_query: str, all_messages: list[dict], + provider: str, model: str, tools: list | None, + agent_id: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + search_results: str | None = None) -> None: + try: + if request_type == RequestType.LOCAL_TOOL: + result = await handle_local_tool_request(user_query) + if result is None: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=temperature, max_tokens=max_tokens, 
top_p=top_p) + result = raw.get("content") if isinstance(raw, dict) else raw + if isinstance(raw, dict) and raw.get("reasoning"): + state["reasoning"] = raw["reasoning"] + else: + result_content = result.get("content") if isinstance(result, dict) else result + if isinstance(result, dict) and result.get("reasoning"): + state["reasoning"] = result["reasoning"] + result = result_content + state["content"] = result or "" + + elif request_type == RequestType.TOOL_CALL: + local_result = await handle_local_tool_request(user_query) + if local_result is not None: + local_content = local_result.get("content") if isinstance(local_result, dict) else local_result + state["content"] = local_content + else: + fc_supported = llm_adapter.supports_tool_calls(provider, model) + if fc_supported: + from app.core.agent.tool_loop import tool_loop, get_all_tools_schema + fc_tools = get_all_tools_schema() + if fc_tools: + loop_kwargs = {} + if temperature is not None: loop_kwargs["temperature"] = temperature + if max_tokens is not None: loop_kwargs["max_tokens"] = max_tokens + if top_p is not None: loop_kwargs["top_p"] = top_p + result = await tool_loop( + messages=all_messages, tools=fc_tools, + provider_name=provider, model=model, **loop_kwargs, + ) + state["content"] = result.get("content", "") + state["reasoning"] = result.get("reasoning", "") + else: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p) + state["content"] = raw.get("content", "") if isinstance(raw, dict) else raw + else: + tool_results = await execute_tool_chain(user_query, agent_id=agent_id, + external_search_results=search_results) + if tool_results: + summary_prompt = build_tool_summary(user_query, tool_results) + summary_messages = [{"role": "user", "content": summary_prompt}] + raw = await llm_adapter.chat(messages=summary_messages, provider_name=provider, + model=model, temperature=temperature, 
max_tokens=max_tokens, top_p=top_p) + state["content"] = raw.get("content", "") if isinstance(raw, dict) else raw + else: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p) + state["content"] = raw.get("content", "") if isinstance(raw, dict) else raw + + else: + fc_supported = llm_adapter.supports_tool_calls(provider, model) + if fc_supported: + from app.core.agent.tool_loop import tool_loop, get_all_tools_schema + fc_tools = get_all_tools_schema() + if fc_tools: + loop_kwargs = {} + if temperature is not None: loop_kwargs["temperature"] = temperature + if max_tokens is not None: loop_kwargs["max_tokens"] = max_tokens + if top_p is not None: loop_kwargs["top_p"] = top_p + result = await tool_loop( + messages=all_messages, tools=fc_tools, + provider_name=provider, model=model, **loop_kwargs, + ) + state["content"] = result.get("content", "") + state["reasoning"] = result.get("reasoning", "") + else: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p) + if isinstance(raw, dict): + state["content"] = raw.get("content", "") + if raw.get("reasoning"): + state["reasoning"] = raw["reasoning"] + else: + state["content"] = raw + else: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p) + if isinstance(raw, dict): + state["content"] = raw.get("content", "") + if raw.get("reasoning"): + state["reasoning"] = raw["reasoning"] + else: + state["content"] = raw + except Exception as e: + logger.error(f"[API] Non-stream error: {e}") + state["aborted"] = True + state["content"] = f"[Error] {str(e)}" + + +async def _STREAM_RESPONSE(conv_id: str, conv: dict, request: ChatRequest, + all_messages: list, user_query: str, request_type: RequestType, + tools: list | None, provider: str, 
model: str, + agent_id: str | None, state: dict, start_time: float, + search_results: str | None = None): + chat_id = str(uuid.uuid4()) + + async def generator(): + try: + if request_type == RequestType.LOCAL_TOOL: + result = await handle_local_tool_request(user_query) + if result is None: + raw = await llm_adapter.chat(messages=all_messages, provider_name=provider, + model=model, temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9) + result_content = raw.get("content") if isinstance(raw, dict) else raw + if isinstance(raw, dict) and raw.get("reasoning"): + state["reasoning"] = raw["reasoning"] + else: + result_content = result.get("content") if isinstance(result, dict) else result + if isinstance(result, dict) and result.get("reasoning"): + state["reasoning"] = result["reasoning"] + state["content"] = result_content or "" + yield _sse(chat_id, state["content"], provider, model) + + elif request_type == RequestType.TOOL_CALL: + local_result = await handle_local_tool_request(user_query) + if local_result is not None: + local_content = local_result.get("content") if isinstance(local_result, dict) else local_result + if isinstance(local_result, dict) and local_result.get("reasoning"): + state["reasoning"] = local_result["reasoning"] + state["content"] = local_content or "" + yield _sse(chat_id, state["content"], provider, model) + else: + fc_supported = llm_adapter.supports_tool_calls(provider, model) + if fc_supported: + from app.core.agent.tool_loop import tool_loop_stream, get_all_tools_schema + fc_tools = get_all_tools_schema() + if fc_tools: + loop_kwargs = {} + if request.temperature is not None: loop_kwargs["temperature"] = request.temperature or 0.7 + if request.max_tokens is not None: loop_kwargs["max_tokens"] = request.max_tokens or 4096 + if request.top_p is not None: loop_kwargs["top_p"] = request.top_p or 0.9 + async for event in tool_loop_stream( + messages=all_messages, tools=fc_tools, + 
provider_name=provider, model=model, **loop_kwargs, + ): + etype = event.get("type") + if etype == "content": + c = event.get("content", "") + state["content"] += c + yield _sse(chat_id, c, provider, model) + elif etype == "reasoning": + rc = event.get("content", "") + state["reasoning"] += rc + yield _sse(chat_id, "", provider, model, rc) + elif etype == "done": + c = event.get("content", "") + if c: + state["content"] += c + yield _sse(chat_id, c, provider, model) + rc = event.get("reasoning", "") + if rc: + state["reasoning"] += rc + else: + async for chunk in llm_adapter.chat_stream( + messages=all_messages, provider_name=provider, model=model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + content = chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + state["content"] += content + if rc: + state["reasoning"] += rc + yield _sse(chat_id, content, provider, model, rc) + else: + yield _sse_reasoning(chat_id, "正在查询所需信息…", provider, model) + tool_results = await execute_tool_chain(user_query, agent_id=agent_id, + external_search_results=search_results) + if tool_results: + summary_prompt = build_tool_summary(user_query, tool_results) + summary_messages = [{"role": "user", "content": summary_prompt}] + async for chunk in llm_adapter.chat_stream( + messages=summary_messages, provider_name=provider, model=model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + content = chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + state["content"] += content + if rc: + state["reasoning"] += rc + yield _sse(chat_id, content, provider, model, rc) + else: + async for chunk in llm_adapter.chat_stream( + messages=all_messages, provider_name=provider, model=model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + content = 
chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + state["content"] += content + if rc: + state["reasoning"] += rc + yield _sse(chat_id, content, provider, model, rc) + + else: + fc_supported = llm_adapter.supports_tool_calls(provider, model) + if fc_supported: + from app.core.agent.tool_loop import tool_loop_stream, get_all_tools_schema + fc_tools = get_all_tools_schema() + if fc_tools: + loop_kwargs = {} + if request.temperature is not None: loop_kwargs["temperature"] = request.temperature or 0.7 + if request.max_tokens is not None: loop_kwargs["max_tokens"] = request.max_tokens or 4096 + if request.top_p is not None: loop_kwargs["top_p"] = request.top_p or 0.9 + async for event in tool_loop_stream( + messages=all_messages, tools=fc_tools, + provider_name=provider, model=model, **loop_kwargs, + ): + etype = event.get("type") + if etype == "content": + c = event.get("content", "") + state["content"] += c + yield _sse(chat_id, c, provider, model) + elif etype == "reasoning": + rc = event.get("content", "") + state["reasoning"] += rc + yield _sse(chat_id, "", provider, model, rc) + elif etype == "done": + c = event.get("content", "") + if c: + state["content"] += c + yield _sse(chat_id, c, provider, model) + rc = event.get("reasoning", "") + if rc: + state["reasoning"] += rc + else: + async for chunk in llm_adapter.chat_stream( + messages=all_messages, provider_name=provider, model=model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + state["content"] += chunk.get("content", "") + rc = chunk.get("reasoning", "") + if rc: + state["reasoning"] += rc + yield _sse(chat_id, chunk.get("content", ""), provider, model, rc) + else: + async for chunk in llm_adapter.chat_stream( + messages=all_messages, tools=tools, provider_name=provider, model=model, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, top_p=request.top_p or 0.9, + ): + 
state["content"] += chunk.get("content", "") + rc = chunk.get("reasoning", "") + if rc: + state["reasoning"] += rc + yield _sse(chat_id, chunk.get("content", ""), provider, model, rc) + + except Exception as e: + state["aborted"] = True + state["content"] = f"[Error] {str(e)}" + logger.error(f"[STREAM] Aborted: conv={conv_id}, error={e}") + yield _sse(chat_id, state["content"], provider, model) + + finally: + _PHASE_3_SAVE_ASSISTANT_MSG(conv, state) + _persist_conv(conv_id, conv) + logger.info(f"[STREAM] Persisted: conv={conv_id}, " + f"content_len={len(state['content'])}, " + f"reasoning_len={len(state['reasoning'])}, " + f"aborted={state['aborted']}") + yield _sse_done(chat_id, provider, model) + _schedule_memory_update(conv["messages"], conv_id, agent_id, provider_name=provider) + + return StreamingResponse(generator(), media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", + "X-Accel-Buffering": "no"}) + + +def _sse(cid: str, content: str, provider: str, model: str, reasoning: str = "") -> str: + return f"data: {ChatStreamChunk(id=cid, content=content, reasoning_content=reasoning, model=model, provider=provider).model_dump_json()}\n\n" + +def _sse_reasoning(cid: str, reasoning: str, provider: str, model: str) -> str: + return f"data: {ChatStreamChunk(id=cid, content='', reasoning_content=reasoning, model=model, provider=provider).model_dump_json()}\n\n" + +def _sse_done(cid: str, provider: str, model: str) -> str: + return f"data: {ChatStreamChunk(id=cid, content='', model=model, provider=provider, done=True).model_dump_json()}\n\n" diff --git a/backend/app/api/v1/endpoints/device.py b/backend/app/api/v1/endpoints/device.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/api/v1/endpoints/iot.py b/backend/app/api/v1/endpoints/iot.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/api/v1/endpoints/plugin.py b/backend/app/api/v1/endpoints/plugin.py deleted file mode 
100644 index e69de29..0000000 diff --git a/backend/app/api/v1/endpoints/session.py b/backend/app/api/v1/endpoints/session.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/api/v1/endpoints/user.py b/backend/app/api/v1/endpoints/user.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/api/ws/__pycache__/__init__.cpython-313.pyc b/backend/app/api/ws/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index f036c7f..0000000 Binary files a/backend/app/api/ws/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/core/__pycache__/__init__.cpython-313.pyc b/backend/app/core/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index fd9bc2a..0000000 Binary files a/backend/app/core/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/core/agent/__init__.py b/backend/app/core/agent/__init__.py new file mode 100644 index 0000000..066851e --- /dev/null +++ b/backend/app/core/agent/__init__.py @@ -0,0 +1,227 @@ +import json +from loguru import logger +from app.runtime.provider.llm.adapter import llm_adapter +from app.utils.tool_executor import execute_tool_by_name +from app.runtime.plugin.skill.registry import SkillRegistry + + +def get_all_tools_schema() -> list[dict]: + try: + return SkillRegistry.get_openai_tools() + except Exception as e: + logger.warning(f"[ToolLoop] get_all_tools_schema failed: {e}") + return [] + + +async def tool_loop( + messages: list[dict], + tools: list[dict], + provider_name: str, + model: str, + max_steps: int = 5, + **kwargs, +) -> dict: + """Tool Loop 核心循环(非流式) + + 返回: + {"content": str, "reasoning": str, "tool_steps": int} + """ + tool_steps = 0 + duplicate_counter: dict[str, int] = {} + + for step in range(max_steps): + raw = await llm_adapter.chat( + messages=messages, + tools=tools, + provider_name=provider_name, + model=model, + return_raw=True, + **kwargs, + ) + + if not isinstance(raw, dict): + return {"content": 
str(raw), "reasoning": "", "tool_steps": tool_steps} + + tool_calls = raw.get("tool_calls", []) + content = raw.get("content", "") or "" + reasoning = raw.get("reasoning", "") or "" + + if not tool_calls: + return {"content": content, "reasoning": reasoning, "tool_steps": tool_steps} + + messages.append({ + "role": "assistant", + "content": content or None, + "tool_calls": tool_calls, + }) + + for tc in tool_calls: + tc_id = tc.get("id", f"call_{step}") + fn = tc.get("function", {}) + tool_name = fn.get("name", "") + args_str = fn.get("arguments", "{}") + + logger.info(f"[ToolLoop] Step {step + 1}: calling {tool_name}({args_str[:100]})") + + duplicate_counter[tool_name] = duplicate_counter.get(tool_name, 0) + 1 + if duplicate_counter[tool_name] >= 3: + logger.warning(f"[ToolLoop] {tool_name} called {duplicate_counter[tool_name]} times, injecting warning") + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": f"[警告] 你已经连续调用 {tool_name} {duplicate_counter[tool_name]} 次了。如果信息仍然不足,请直接根据已有信息回答用户。", + }) + continue + + try: + args = json.loads(args_str) if args_str else {} + except json.JSONDecodeError: + args = {} + + result = await execute_tool_by_name(tool_name, args) + + if len(result) > 2000: + result = result[:2000] + "...(结果已截断)" + + logger.info(f"[ToolLoop] {tool_name} → {len(result)} chars") + tool_steps += 1 + + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": result, + }) + + messages.append({"role": "user", "content": "请根据已获取的信息总结回答用户的问题,不要再调用工具。"}) + final = await llm_adapter.chat( + messages=messages, + provider_name=provider_name, + model=model, + **kwargs, + ) + final_content = final.get("content", "") if isinstance(final, dict) else str(final) + final_reasoning = final.get("reasoning", "") if isinstance(final, dict) else "" + return {"content": final_content, "reasoning": final_reasoning, "tool_steps": tool_steps} + + +async def tool_loop_stream( + messages: list[dict], + tools: list[dict], + provider_name: 
str, + model: str, + max_steps: int = 5, + **kwargs, +): + """Tool Loop 核心循环(流式) + + 每轮 yield: + - {"type": "content", "content": str} — LLM 文本输出 + - {"type": "reasoning", "content": str} — 推理/状态提示 + - {"type": "done", "content": str, "reasoning": str} — 最终结果 + """ + tool_steps = 0 + duplicate_counter: dict[str, int] = {} + + for step in range(max_steps): + collected_content = "" + collected_reasoning = "" + collected_tool_calls: dict[int, dict] = {} + + async for chunk in llm_adapter.chat_stream( + messages=messages, + tools=tools, + provider_name=provider_name, + model=model, + **kwargs, + ): + content = chunk.get("content", "") + reasoning = chunk.get("reasoning", "") + tc_complete = chunk.get("tool_calls_complete") + + if content: + collected_content += content + yield {"type": "content", "content": content} + if reasoning: + collected_reasoning += reasoning + + if tc_complete: + for tc in tc_complete: + idx = tc.get("index", len(collected_tool_calls)) + collected_tool_calls[idx] = tc + + if not collected_tool_calls: + yield { + "type": "done", + "content": collected_content, + "reasoning": collected_reasoning, + } + return + + messages.append({ + "role": "assistant", + "content": collected_content or None, + "tool_calls": [ + { + "id": v.get("id", f"call_{step}_{k}"), + "type": "function", + "function": v.get("function", {}), + } + for k, v in collected_tool_calls.items() + ], + }) + + for idx in sorted(collected_tool_calls.keys()): + tc_data = collected_tool_calls[idx] + fn = tc_data.get("function", {}) + tool_name = fn.get("name", "") + args_str = fn.get("arguments", "{}") + tc_id = tc_data.get("id", f"call_{step}_{idx}") + + logger.info(f"[ToolLoop] Step {step + 1} stream: calling {tool_name}({args_str[:100]})") + + duplicate_counter[tool_name] = duplicate_counter.get(tool_name, 0) + 1 + if duplicate_counter[tool_name] >= 3: + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": f"[警告] 你已经连续调用 {tool_name} {duplicate_counter[tool_name]} 
次了。如果信息仍然不足,请直接根据已有信息回答用户。", + }) + continue + + yield {"type": "reasoning", "content": f"正在查询 {tool_name}..."} + + try: + args = json.loads(args_str) if args_str else {} + except json.JSONDecodeError: + args = {} + + result = await execute_tool_by_name(tool_name, args) + + if len(result) > 2000: + result = result[:2000] + "...(结果已截断)" + + logger.info(f"[ToolLoop] {tool_name} → {len(result)} chars") + tool_steps += 1 + + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": result, + }) + + messages.append({"role": "user", "content": "请根据已获取的信息总结回答用户的问题,不要再调用工具。"}) + async for chunk in llm_adapter.chat_stream( + messages=messages, + tools=None, + provider_name=provider_name, + model=model, + **kwargs, + ): + content = chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + yield {"type": "content", "content": content} + if rc: + collected_reasoning += rc + + yield {"type": "done", "content": "", "reasoning": collected_reasoning} diff --git a/backend/app/core/agent/tool_loop.py b/backend/app/core/agent/tool_loop.py new file mode 100644 index 0000000..49ece34 --- /dev/null +++ b/backend/app/core/agent/tool_loop.py @@ -0,0 +1,216 @@ +import json +from loguru import logger +from app.runtime.provider.llm.adapter import llm_adapter +from app.utils.tool_executor import execute_tool_by_name +from app.runtime.plugin.skill.registry import SkillRegistry + + +def get_all_tools_schema() -> list[dict]: + try: + return SkillRegistry.get_openai_tools() + except Exception as e: + logger.warning(f"[ToolLoop] get_all_tools_schema failed: {e}") + return [] + + +async def tool_loop( + messages: list[dict], + tools: list[dict], + provider_name: str, + model: str, + max_steps: int = 5, + **kwargs, +) -> dict: + tool_steps = 0 + duplicate_counter: dict[str, int] = {} + + for step in range(max_steps): + raw = await llm_adapter.chat( + messages=messages, + tools=tools, + provider_name=provider_name, + model=model, + return_raw=True, + **kwargs, + ) + + 
if not isinstance(raw, dict): + return {"content": str(raw), "reasoning": "", "tool_steps": tool_steps} + + tool_calls = raw.get("tool_calls", []) + content = raw.get("content", "") or "" + reasoning = raw.get("reasoning", "") or "" + + if not tool_calls: + return {"content": content, "reasoning": reasoning, "tool_steps": tool_steps} + + messages.append({ + "role": "assistant", + "content": content or None, + "tool_calls": tool_calls, + }) + + for tc in tool_calls: + tc_id = tc.get("id", f"call_{step}") + fn = tc.get("function", {}) + tool_name = fn.get("name", "") + args_str = fn.get("arguments", "{}") + + logger.info(f"[ToolLoop] Step {step + 1}: calling {tool_name}({args_str[:100]})") + + duplicate_counter[tool_name] = duplicate_counter.get(tool_name, 0) + 1 + if duplicate_counter[tool_name] >= 3: + logger.warning(f"[ToolLoop] {tool_name} called {duplicate_counter[tool_name]} times, injecting warning") + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": f"[警告] 你已经连续调用 {tool_name} {duplicate_counter[tool_name]} 次了。如果信息仍然不足,请直接根据已有信息回答用户。", + }) + continue + + try: + args = json.loads(args_str) if args_str else {} + except json.JSONDecodeError: + args = {} + + result = await execute_tool_by_name(tool_name, args) + + if len(result) > 2000: + result = result[:2000] + "...(结果已截断)" + + logger.info(f"[ToolLoop] {tool_name} → {len(result)} chars") + tool_steps += 1 + + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": result, + }) + + messages.append({"role": "user", "content": "请根据已获取的信息总结回答用户的问题,不要再调用工具。"}) + final = await llm_adapter.chat( + messages=messages, + provider_name=provider_name, + model=model, + **kwargs, + ) + final_content = final.get("content", "") if isinstance(final, dict) else str(final) + final_reasoning = final.get("reasoning", "") if isinstance(final, dict) else "" + return {"content": final_content, "reasoning": final_reasoning, "tool_steps": tool_steps} + + +async def tool_loop_stream( + messages: 
list[dict], + tools: list[dict], + provider_name: str, + model: str, + max_steps: int = 5, + **kwargs, +): + tool_steps = 0 + duplicate_counter: dict[str, int] = {} + + for step in range(max_steps): + collected_content = "" + collected_reasoning = "" + collected_tool_calls: dict[int, dict] = {} + + async for chunk in llm_adapter.chat_stream( + messages=messages, + tools=tools, + provider_name=provider_name, + model=model, + **kwargs, + ): + content = chunk.get("content", "") + reasoning = chunk.get("reasoning", "") + tc_complete = chunk.get("tool_calls_complete") + + if content: + collected_content += content + yield {"type": "content", "content": content} + if reasoning: + collected_reasoning += reasoning + yield {"type": "reasoning", "content": reasoning} + + if tc_complete: + for tc in tc_complete: + idx = tc.get("index", len(collected_tool_calls)) + collected_tool_calls[idx] = tc + + if not collected_tool_calls: + yield { + "type": "done", + "content": collected_content, + "reasoning": collected_reasoning, + } + return + + messages.append({ + "role": "assistant", + "content": collected_content or None, + "tool_calls": [ + { + "id": v.get("id", f"call_{step}_{k}"), + "type": "function", + "function": v.get("function", {}), + } + for k, v in collected_tool_calls.items() + ], + }) + + for idx in sorted(collected_tool_calls.keys()): + tc_data = collected_tool_calls[idx] + fn = tc_data.get("function", {}) + tool_name = fn.get("name", "") + args_str = fn.get("arguments", "{}") + tc_id = tc_data.get("id", f"call_{step}_{idx}") + + logger.info(f"[ToolLoop] Step {step + 1} stream: calling {tool_name}({args_str[:100]})") + + duplicate_counter[tool_name] = duplicate_counter.get(tool_name, 0) + 1 + if duplicate_counter[tool_name] >= 3: + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": f"[警告] 你已经连续调用 {tool_name} {duplicate_counter[tool_name]} 次了。如果信息仍然不足,请直接根据已有信息回答用户。", + }) + continue + + yield {"type": "reasoning", "content": f"正在查询 
{tool_name}..."} + + try: + args = json.loads(args_str) if args_str else {} + except json.JSONDecodeError: + args = {} + + result = await execute_tool_by_name(tool_name, args) + + if len(result) > 2000: + result = result[:2000] + "...(结果已截断)" + + logger.info(f"[ToolLoop] {tool_name} → {len(result)} chars") + tool_steps += 1 + + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": result, + }) + + messages.append({"role": "user", "content": "请根据已获取的信息总结回答用户的问题,不要再调用工具。"}) + async for chunk in llm_adapter.chat_stream( + messages=messages, + tools=None, + provider_name=provider_name, + model=model, + **kwargs, + ): + content = chunk.get("content", "") + rc = chunk.get("reasoning", "") + if content: + yield {"type": "content", "content": content} + if rc: + collected_reasoning += rc + + yield {"type": "done", "content": "", "reasoning": collected_reasoning} diff --git a/backend/app/core/app_factory.py b/backend/app/core/app_factory.py index f1acb82..cbdffb5 100644 --- a/backend/app/core/app_factory.py +++ b/backend/app/core/app_factory.py @@ -8,6 +8,7 @@ from app.core.config import settings from app.core.exceptions import LuomiNestError from app.api.v1.router import api_router +from app.api.attachment_api import router as attachment_router @asynccontextmanager @@ -89,6 +90,7 @@ async def generic_error_handler(request: Request, exc: Exception): ) app.include_router(api_router, prefix="/api/v1") + app.include_router(attachment_router, prefix="/api") @app.get("/health") async def health_check(): diff --git a/backend/app/core/config.py b/backend/app/core/config.py index cf9ba1b..abffa47 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -42,6 +42,9 @@ class Settings(BaseSettings): PLUGIN_DIR: str = "./plugins" SKILL_DIR: str = "./skills" + EXTERNAL_PARSE_API_URL: str = "" + FILE_MAX_SIZE: int = 100 * 1024 * 1024 + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/backend/app/domains/__init__.py 
b/backend/app/domains/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/companion/dialogue_manager.py b/backend/app/domains/companion/dialogue_manager.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/companion/persona.py b/backend/app/domains/companion/persona.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/companion/storyteller.py b/backend/app/domains/companion/storyteller.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/connect/device_tracker.py b/backend/app/domains/connect/device_tracker.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/connect/seamless_follow.py b/backend/app/domains/connect/seamless_follow.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/connect/sync_service.py b/backend/app/domains/connect/sync_service.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/hwctrl/gpio_controller.py b/backend/app/domains/hwctrl/gpio_controller.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/hwctrl/mcu_protocol.py b/backend/app/domains/hwctrl/mcu_protocol.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/hwctrl/relay_manager.py b/backend/app/domains/hwctrl/relay_manager.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/intent_classifier.py b/backend/app/domains/intent_classifier.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/iot/custom_device.py b/backend/app/domains/iot/custom_device.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/iot/device_hub.py b/backend/app/domains/iot/device_hub.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/iot/ha_adapter.py b/backend/app/domains/iot/ha_adapter.py deleted file mode 
100644 index e69de29..0000000 diff --git a/backend/app/domains/iot/scene_automation.py b/backend/app/domains/iot/scene_automation.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/iot/xiaomi_adapter.py b/backend/app/domains/iot/xiaomi_adapter.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/knowledge/profile/preference_learner.py b/backend/app/domains/knowledge/profile/preference_learner.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/knowledge/profile/user_profile.py b/backend/app/domains/knowledge/profile/user_profile.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/mcp_tools/client.py b/backend/app/domains/mcp_tools/client.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/mcp_tools/registry.py b/backend/app/domains/mcp_tools/registry.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/multimodal/image/generator.py b/backend/app/domains/multimodal/image/generator.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/multimodal/vision/image_analyzer.py b/backend/app/domains/multimodal/vision/image_analyzer.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/orchestrator.py b/backend/app/domains/orchestrator.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/router.py b/backend/app/domains/router.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/social/contact_manager.py b/backend/app/domains/social/contact_manager.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/social/friend_request.py b/backend/app/domains/social/friend_request.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/domains/tool_executor.py b/backend/app/domains/tool_executor.py deleted file mode 100644 index 
e69de29..0000000 diff --git a/backend/app/infrastructure/__pycache__/__init__.cpython-313.pyc b/backend/app/infrastructure/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index d35f7aa..0000000 Binary files a/backend/app/infrastructure/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/infrastructure/storage/file_manager.py b/backend/app/infrastructure/storage/file_manager.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/mcp/__init__.py b/backend/app/mcp/__init__.py new file mode 100644 index 0000000..682119e --- /dev/null +++ b/backend/app/mcp/__init__.py @@ -0,0 +1,40 @@ +""" +LuomiNest MCP 服务包 + +提供符合 MCP(Model Context Protocol)2024-11-05 标准的工具 Server。 +每个 Server 可独立运行,通过 stdio JSON-RPC 2.0 与 MCP 客户端通信。 + +提供的 MCP Server: + - time_server:时间查询工具(get_current_time) + - weather_server:天气查询工具(get_weather_info) + +使用方式: + # 直接运行(命令行启动) + python -m app.mcp.servers.time_server + python -m app.mcp.servers.weather_server + + # Trae IDE 配置(.trae/mcp.json) + { + "mcpServers": { + "luomi-time": { + "command": "python", + "args": ["-m", "app.mcp.servers.time_server"], + "cwd": "/path/to/LuomiNest/backend" + }, + "luomi-weather": { + "command": "python", + "args": ["-m", "app.mcp.servers.weather_server"], + "cwd": "/path/to/LuomiNest/backend" + } + } + } + + # 编程方式调用 + from app.mcp.servers import create_time_server, create_weather_server + time_srv = create_time_server() + await time_srv() +""" + +from app.mcp.servers import create_time_server, create_weather_server + +__all__ = ["create_time_server", "create_weather_server"] diff --git a/backend/app/mcp/servers/__init__.py b/backend/app/mcp/servers/__init__.py new file mode 100644 index 0000000..69152f1 --- /dev/null +++ b/backend/app/mcp/servers/__init__.py @@ -0,0 +1,11 @@ +""" +MCP Server 子包 —— 存放所有标准 MCP Server 实现 + +每个 Server 是一个独立的 Python 模块, +可通过 stdio(标准输入/输出)以 JSON-RPC 2.0 协议与 MCP 客户端通信。 +""" + +from app.mcp.servers.time_server import 
create_time_server +from app.mcp.servers.weather_server import create_weather_server + +__all__ = ["create_time_server", "create_weather_server"] diff --git a/backend/app/mcp/servers/time_server.py b/backend/app/mcp/servers/time_server.py new file mode 100644 index 0000000..ea2f634 --- /dev/null +++ b/backend/app/mcp/servers/time_server.py @@ -0,0 +1,287 @@ +""" +时间 MCP Server —— 把现有 time_tool 封装为标准 MCP 工具服务 + +MCP 协议版本:2024-11-05 +通信方式:stdio(标准输入/输出),JSON-RPC 2.0 + +提供的工具: + - get_current_time:获取当前日期、时间、星期信息 + +设计原则: + 1. 纯 MCP 协议封装层,业务逻辑完全复用 time_tool.py 和 SkillRegistry + 2. 严格遵循 MCP 官方规范(tools/list + tools/call + initialize) + 3. 零外部依赖(仅用 Python 标准库),可在任意环境直接运行 + 4. 全链路异常处理,工具调用失败时返回结构化错误内容 + +用法(命令行启动): + python -m app.mcp.servers.time_server + +用法(Trae 配置 .trae/mcp.json): + { + "mcpServers": { + "luomi-time": { + "command": "python", + "args": ["-m", "app.mcp.servers.time_server"], + "cwd": "/path/to/LuomiNest/backend" + } + } + } +""" + +import sys +import json +import asyncio +from datetime import datetime +from loguru import logger + + +# ============================================================================= +# 工具定义(符合 MCP Tool 规范) +# ============================================================================= + +TOOLS = [ + { + "name": "get_current_time", + "description": ( + "获取当前日期和时间信息。返回包含日期、时间、星期、年、月、日、时、分、秒的详细数据。" + "用户询问'现在几点'、'今天几号'、'今天星期几'时使用此工具。" + ), + "inputSchema": { + "type": "object", + "properties": {}, + "required": [], + }, + }, +] + + +# ============================================================================= +# 工具调用处理 —— 复用现有 time_tool 逻辑 +# ============================================================================= + +def _call_get_current_time(arguments: dict) -> str: + """调用获取当前时间 —— 优先使用 time_tool,降级使用 SkillRegistry + + 两层降级策略: + 1. 尝试导入 TimeTool(自然语言格式回复,更友好) + 2. 降级到 SkillRegistry._builtin_get_time(结构化 JSON 回复) + 3. 
最终兜底:纯 Python datetime + """ + # 第一层:使用 time_tool.py 的 TimeTool(自然语言回复) + try: + from app.utils.time_tool import TimeTool + tool = TimeTool(timezone="Asia/Shanghai") + return tool.get_reply("all") + except Exception as e: + logger.debug(f"[MCP-Time] TimeTool 不可用 ({e}),降级到 SkillRegistry") + + # 第二层:降级到 SkillRegistry 内置的时间获取 + try: + from app.runtime.plugin.skill.registry import SkillRegistry + import asyncio as _asyncio + try: + loop = _asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + logger.debug("[MCP-Time] 已在运行的事件循环中,跳过 SkillRegistry 异步调用") + else: + result = _asyncio.run(SkillRegistry._builtin_get_time()) + data = result.data if hasattr(result, "data") else result + weekday = data.get("weekday", "") + date = data.get("date", "") + time = data.get("time", "") + return f"日期:{date} {weekday},时间:{time}" + except Exception as e: + logger.debug(f"[MCP-Time] SkillRegistry 不可用 ({e}),降级到纯 datetime") + + # 第三层:最终兜底 —— 纯 Python 标准库 + now = datetime.now() + weekday_names = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] + weekday = weekday_names[now.weekday()] + return ( + f"日期:{now.strftime('%Y-%m-%d')} {weekday}," + f"时间:{now.strftime('%H:%M:%S')}" + ) + + +# ============================================================================= +# MCP 协议处理 —— JSON-RPC 2.0 over stdio +# ============================================================================= + +def _build_response(request_id, result): + """构建成功响应""" + return json.dumps({ + "jsonrpc": "2.0", + "id": request_id, + "result": result, + }, ensure_ascii=False) + + +def _build_error(request_id, code: int, message: str): + """构建错误响应""" + return json.dumps({ + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": code, + "message": message, + }, + }, ensure_ascii=False) + + +def handle_request(request: dict) -> str | None: + """处理单个 JSON-RPC 请求,返回响应 JSON 字符串 + + 支持的 MCP 方法: + - initialize:握手初始化 + - tools/list:列出工具 + - tools/call:调用工具 + - 
notifications/initialized:初始化通知(无需响应) + + 参数: + request: 解析后的 JSON-RPC 请求字典 + + 返回: + JSON 响应字符串,通知类请求返回 None + """ + method = request.get("method", "") + request_id = request.get("id") + params = request.get("params", {}) + + # ----- initialize:MCP 握手 ----- + if method == "initialize": + return _build_response(request_id, { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {}, + }, + "serverInfo": { + "name": "luomi-time-server", + "version": "1.0.0", + }, + }) + + # ----- notifications/initialized:初始化完成(无需响应)----- + if method == "notifications/initialized": + return None + + # ----- tools/list:返回工具列表 ----- + if method == "tools/list": + return _build_response(request_id, {"tools": TOOLS}) + + # ----- tools/call:执行工具调用 ----- + if method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_current_time": + try: + result_text = _call_get_current_time(arguments) + return _build_response(request_id, { + "content": [ + {"type": "text", "text": result_text}, + ], + }) + except Exception as e: + return _build_response(request_id, { + "content": [ + {"type": "text", "text": f"获取时间信息失败:{e}"}, + ], + "isError": True, + }) + else: + return _build_error(request_id, -32601, f"未知工具: {tool_name}") + + # ----- 未知方法 ----- + return _build_error(request_id, -32601, f"未知方法: {method}") + + +# ============================================================================= +# 主循环 —— stdio 通信 +# ============================================================================= + +async def run_server(): + """MCP Server 主循环 —— 从 stdin 读取 JSON-RPC 请求,处理后写到 stdout + + 使用 run_in_executor 将阻塞读取放到线程池, + 兼容 Windows 和所有平台。 + """ + loop = asyncio.get_running_loop() + + async def _read_line() -> str: + try: + line = await asyncio.wait_for( + loop.run_in_executor(None, sys.stdin.readline), + timeout=300, + ) + return line + except TimeoutError: + return "" + + while True: + try: + line = await _read_line() + except 
Exception: + break + + if not line: + break + + line = line.strip() + if not line: + continue + + # 解析 JSON-RPC 请求 + try: + request = json.loads(line) + except json.JSONDecodeError: + err = _build_error(None, -32700, "JSON 解析错误") + sys.stdout.write(err + "\n") + sys.stdout.flush() + continue + + # 处理请求 + try: + response = handle_request(request) + except Exception as e: + response = _build_error( + request.get("id"), + -32603, + f"内部错误: {e}", + ) + + # 输出响应 + if response is not None: + sys.stdout.write(response + "\n") + sys.stdout.flush() + + +def create_time_server(): + """工厂函数 —— 创建并返回时间 MCP Server 的启动函数 + + 返回: + run_server 协程,调用方可 await 启动服务 + """ + return run_server + + +# ============================================================================= +# 直接运行入口 +# ============================================================================= + +if __name__ == "__main__": + # 配置 loguru 输出到 stderr(避免污染 stdout 的 JSON-RPC 通信) + logger.remove() + logger.add(sys.stderr, level="INFO", format="{level} | {message}") + + logger.info("LuomiNest 时间 MCP Server 启动中...") + logger.info("工具列表: get_current_time") + logger.info("协议版本: 2024-11-05") + try: + asyncio.run(run_server()) + except KeyboardInterrupt: + logger.info("时间 MCP Server 已停止") + except Exception as e: + logger.error(f"时间 MCP Server 异常退出: {e}") + sys.exit(1) diff --git a/backend/app/mcp/servers/weather_server.py b/backend/app/mcp/servers/weather_server.py new file mode 100644 index 0000000..b9563a1 --- /dev/null +++ b/backend/app/mcp/servers/weather_server.py @@ -0,0 +1,346 @@ +""" +天气 MCP Server —— 把现有天气爬虫工具封装为标准 MCP 工具服务 + +MCP 协议版本:2024-11-05 +通信方式:stdio(标准输入/输出),JSON-RPC 2.0 + +提供的工具: + - get_weather_info:获取指定城市的天气信息(含温度、天气状况、出行建议) + +设计原则: + 1. 纯 MCP 协议封装层,业务逻辑完全复用 weather.py 的 get_weather 函数 + 2. 严格遵循 MCP 官方规范(tools/list + tools/call + initialize) + 3. 支持异步天气查询(Open-Meteo API),带本地缓存(复用 weather.py 缓存) + 4. 
全链路异常处理,工具调用失败时返回结构化错误内容 + +用法(命令行启动): + python -m app.mcp.servers.weather_server + +用法(Trae 配置 .trae/mcp.json): + { + "mcpServers": { + "luomi-weather": { + "command": "python", + "args": ["-m", "app.mcp.servers.weather_server"], + "cwd": "/path/to/LuomiNest/backend" + } + } + } +""" + +import sys +import json +import asyncio +from datetime import datetime, timedelta +from loguru import logger + + +# ============================================================================= +# 工具定义(符合 MCP Tool 规范) +# ============================================================================= + +TOOLS = [ + { + "name": "get_weather_info", + "description": ( + "获取指定城市的天气信息。返回城市名称、天气状况、气温、出行建议。" + "当用户询问'天气怎么样'、'会不会下雨'、'穿什么衣服'、'气温多少度'时使用此工具。" + "支持查询今天、明天、后天的天气。" + ), + "inputSchema": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称,如:北京、上海、广州、深圳", + }, + "date": { + "type": "string", + "description": "日期,如:今天、明天、后天、2026-05-03,默认为今天", + }, + }, + "required": ["city"], + }, + }, +] + + +# ============================================================================= +# 日期参数解析 +# ============================================================================= + +def _resolve_date(date_input: str) -> str: + """将自然语言日期(今天/明天/后天)转为 YYYY-MM-DD 格式 + + 参数: + date_input: 日期输入,可为 "今天"、"明天"、"后天"、YYYY-MM-DD 或空 + + 返回: + YYYY-MM-DD 格式的日期字符串 + """ + if not date_input or date_input in ["今天", "今日", ""]: + return datetime.now().strftime("%Y-%m-%d") + if date_input in ["明天", "明日"]: + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d") + if date_input in ["后天"]: + return (datetime.now() + timedelta(days=2)).strftime("%Y-%m-%d") + # 已经是 YYYY-MM-DD 格式或其它格式,原样返回 + return date_input + + +# ============================================================================= +# 工具调用处理 —— 复用现有天气爬虫逻辑 +# ============================================================================= + +async def _call_get_weather_info(arguments: dict) -> str: + """调用获取天气信息 —— 
复用 weather.py 的 get_weather 和 _format_weather_for_user + + 三层降级策略: + 1. 尝试导入 weather.py 的 get_weather(完整 Open-Meteo API + 缓存) + 2. 降级到 SkillRegistry(如果天气已注册) + 3. 最终兜底:返回结构化错误信息 + + 参数: + arguments: {"city": "北京", "date": "今天"} + + 返回: + 格式化的天气信息自然语言字符串 + """ + city = arguments.get("city", "") + date_input = arguments.get("date", "") + + if not city: + return "请提供城市名称,如:北京、上海、广州" + + date = _resolve_date(date_input) + + # 第一层:使用 weather.py 的 get_weather(完整 API + 缓存) + try: + from app.runtime.plugin.skill.builtin.weather import ( + get_weather, + _format_weather_for_user, + ) + result = await get_weather(city=city, date=date_input) + if result.success and result.data: + # weather.py 的 get_weather 已在内部调用了 _format_weather_for_user + formatted = result.data.get("formatted", "") + if formatted: + return formatted + # 回退:手动格式化 + return _format_weather_for_user(result.data) + return result.error or "未能获取天气数据" + except Exception as e: + logger.debug(f"[MCP-Weather] weather.py 不可用 ({e}),降级到 SkillRegistry") + + # 第二层:降级到 SkillRegistry + try: + from app.runtime.plugin.skill.registry import SkillRegistry + handler = SkillRegistry.get_handler("get_weather") + if handler: + result = await handler(city=city, date=date_input) + if hasattr(result, "to_text"): + return result.to_text() + return str(result) + return f"天气工具未注册,无法获取 {city} 的天气信息。" + except Exception as e: + logger.debug(f"[MCP-Weather] SkillRegistry 不可用 ({e}),降级到兜底") + + # 第三层:最终兜底 + return f"暂时无法获取 {city}({date})的天气数据,建议查看天气预报应用获取最新信息。" + + +# ============================================================================= +# MCP 协议处理 —— JSON-RPC 2.0 over stdio +# ============================================================================= + +def _build_response(request_id, result): + """构建成功响应""" + return json.dumps({ + "jsonrpc": "2.0", + "id": request_id, + "result": result, + }, ensure_ascii=False) + + +def _build_error(request_id, code: int, message: str): + """构建错误响应""" + return json.dumps({ + "jsonrpc": "2.0", + 
"id": request_id, + "error": { + "code": code, + "message": message, + }, + }, ensure_ascii=False) + + +def handle_request(request: dict) -> str | None: + """处理单个 JSON-RPC 请求,返回响应 JSON 字符串 + + 支持的 MCP 方法: + - initialize:握手初始化 + - tools/list:列出工具 + - tools/call:调用工具 + - notifications/initialized:初始化通知(无需响应) + + 注意:tools/call 中的天气查询是异步的,需要在主循环中 await。 + 此函数返回一个特殊标记 "ASYNC_WEATHER",主循环检测到后执行异步调用。 + + 参数: + request: 解析后的 JSON-RPC 请求字典 + + 返回: + JSON 响应字符串,通知类请求返回 None, + tools/call 天气请求返回 ("ASYNC_WEATHER", request_id, arguments) 元组 + """ + method = request.get("method", "") + request_id = request.get("id") + params = request.get("params", {}) + + # ----- initialize:MCP 握手 ----- + if method == "initialize": + return _build_response(request_id, { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {}, + }, + "serverInfo": { + "name": "luomi-weather-server", + "version": "1.0.0", + }, + }) + + # ----- notifications/initialized:初始化完成(无需响应)----- + if method == "notifications/initialized": + return None + + # ----- tools/list:返回工具列表 ----- + if method == "tools/list": + return _build_response(request_id, {"tools": TOOLS}) + + # ----- tools/call:执行工具调用 ----- + if method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_weather_info": + # 天气查询是异步的,返回特殊标记让主循环处理 + return ("ASYNC_WEATHER", request_id, arguments) + else: + return _build_error(request_id, -32601, f"未知工具: {tool_name}") + + # ----- 未知方法 ----- + return _build_error(request_id, -32601, f"未知方法: {method}") + + +async def _handle_weather_call(request_id, arguments: dict) -> str: + """处理异步天气查询并构建响应""" + try: + result_text = await _call_get_weather_info(arguments) + return _build_response(request_id, { + "content": [ + {"type": "text", "text": result_text}, + ], + }) + except Exception as e: + return _build_response(request_id, { + "content": [ + {"type": "text", "text": f"获取天气信息失败:{e}"}, + ], + "isError": True, + }) + + +# 
============================================================================= +# 主循环 —— stdio 通信 +# ============================================================================= + +async def run_server(): + """MCP Server 主循环 —— 从 stdin 读取 JSON-RPC 请求,处理后写到 stdout + + 天气查询是异步的(httpx 请求 Open-Meteo API), + 在主循环中 await 确保非阻塞执行。 + """ + loop = asyncio.get_event_loop() + + async def _read_line() -> str: + try: + line = await asyncio.wait_for( + loop.run_in_executor(None, sys.stdin.readline), + timeout=300, + ) + return line + except asyncio.TimeoutError: + return "" + + while True: + try: + line = await _read_line() + except Exception: + break + + if not line: + break + + line = line.strip() + if not line: + continue + + # 解析 JSON-RPC 请求 + try: + request = json.loads(line) + except json.JSONDecodeError: + err = _build_error(None, -32700, "JSON 解析错误") + sys.stdout.write(err + "\n") + sys.stdout.flush() + continue + + # 处理请求 + try: + response = handle_request(request) + except Exception as e: + response = _build_error( + request.get("id"), + -32603, + f"内部错误: {e}", + ) + + # 处理异步天气调用 + if isinstance(response, tuple) and response[0] == "ASYNC_WEATHER": + _, req_id, args = response + response = await _handle_weather_call(req_id, args) + + # 输出响应 + if response is not None: + sys.stdout.write(response + "\n") + sys.stdout.flush() + + +def create_weather_server(): + """工厂函数 —— 创建并返回天气 MCP Server 的启动函数 + + 返回: + run_server 协程,调用方可 await 启动服务 + """ + return run_server + + +# ============================================================================= +# 直接运行入口 +# ============================================================================= + +if __name__ == "__main__": + # 配置 loguru 输出到 stderr(避免污染 stdout 的 JSON-RPC 通信) + logger.remove() + logger.add(sys.stderr, level="INFO", format="{level} | {message}") + + logger.info("LuomiNest 天气 MCP Server 启动中...") + logger.info("工具列表: get_weather_info") + logger.info("协议版本: 2024-11-05") + try: + asyncio.run(run_server()) + except 
KeyboardInterrupt: + logger.info("天气 MCP Server 已停止") + except Exception as e: + logger.error(f"天气 MCP Server 异常退出: {e}") + sys.exit(1) diff --git a/backend/app/mcp/tests/__init__.py b/backend/app/mcp/tests/__init__.py new file mode 100644 index 0000000..9e5944b --- /dev/null +++ b/backend/app/mcp/tests/__init__.py @@ -0,0 +1,3 @@ +""" +MCP 测试包 +""" diff --git a/backend/app/mcp/tests/test_mcp_protocol.py b/backend/app/mcp/tests/test_mcp_protocol.py new file mode 100644 index 0000000..23c3d58 --- /dev/null +++ b/backend/app/mcp/tests/test_mcp_protocol.py @@ -0,0 +1,205 @@ +""" +MCP Server 协议兼容性验证脚本(手动模拟 stdio) + +直接调用各 Server 的 handle_request 函数,模拟完整的 MCP 协议交互: +initialize → tools/list → tools/call + +不依赖 subprocess,适用于 Windows 沙箱环境。 + +用法: + python -m app.mcp.tests.test_mcp_protocol +""" + +import sys +import os +import json +import asyncio +import time + + +def simulate_client(handle_fn, server_name: str, icon: str, extra_async_handler=None): + """模拟一个 MCP 客户端与 Server 通信 + + 通过直接调用 handle_fn 发送 JSON-RPC 请求,无需子进程。 + + 参数: + handle_fn: Server 的请求处理函数 + server_name: 服务器名称 + icon: 图标(emoji) + extra_async_handler: 异步回调(用于天气查询等需要 await 的调用) + """ + next_id = 1 + results = [] + start = time.time() + + def _send(method, params=None, expect_content=True): + nonlocal next_id + request = { + "jsonrpc": "2.0", + "id": next_id, + "method": method, + "params": params or {}, + } + next_id += 1 + + response_str = handle_fn(request) + if response_str is None: + return None + + # 处理异步标记(天气查询返回 ("ASYNC_WEATHER", id, args) 元组) + if isinstance(response_str, tuple) and response_str[0] == "ASYNC_WEATHER": + return {"_async": True, "_id": response_str[1], "_args": response_str[2]} + + response = json.loads(response_str) + + if "error" in response: + return response + + if expect_content and "result" not in response: + raise AssertionError(f"{method} 缺少 result: {response}") + + return response + + try: + # ===== 1. 
initialize 握手 ===== + print(f"\n [1/4] {server_name} initialize 握手...") + resp = _send("initialize", { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }) + assert resp["result"]["protocolVersion"] == "2024-11-05" + assert "tools" in resp["result"]["capabilities"] + svr_info = resp["result"]["serverInfo"] + print(f" 协议版本: {resp['result']['protocolVersion']}") + print(f" 服务名称: {svr_info['name']} v{svr_info['version']}") + results.append("PASS") + + # ===== 2. 初始化完成通知 ===== + resp = _send("notifications/initialized", expect_content=False) + assert resp is None + results.append("PASS") + + # ===== 3. tools/list ===== + print(f"\n [2/4] {server_name} tools/list 查询工具列表...") + resp = _send("tools/list") + tools = resp["result"]["tools"] + assert len(tools) >= 1 + tool_names = [t["name"] for t in tools] + print(f" 工具数量: {len(tools)}") + print(f" 工具列表: {tool_names}") + + for tool in tools: + assert "name" in tool, f"工具缺少 name" + assert "description" in tool, f"工具缺少 description" + assert "inputSchema" in tool, f"工具缺少 inputSchema" + assert tool["inputSchema"]["type"] == "object" + print(f" - {tool['name']}: {tool['description'][:50]}...") + + results.append("PASS") + + # ===== 4. 
tools/call ===== + print(f"\n [3/4] {server_name} tools/call 调用工具...") + first_tool = tools[0] + call_args = {} + + # 如果工具需要参数,提供默认值 + if "properties" in first_tool["inputSchema"]: + for prop_name, prop_info in first_tool["inputSchema"]["properties"].items(): + if prop_name == "city": + call_args["city"] = "北京" + elif prop_name == "date": + call_args["date"] = "今天" + + resp = _send("tools/call", { + "name": first_tool["name"], + "arguments": call_args, + }) + + # 处理异步响应(天气查询返回 {"_async": True} 标记) + if isinstance(resp, dict) and resp.get("_async"): + async def _do_async(): + return await extra_async_handler(resp["_id"], resp["_args"]) + resp = asyncio.run(_do_async()) + resp = json.loads(resp) + + content = resp["result"]["content"] + assert len(content) >= 1 + assert content[0]["type"] == "text" + assert content[0]["text"] + print(f" 返回内容: {content[0]['text'][:100]}...") + results.append("PASS") + + # ===== 5. 错误处理:未知工具 ===== + print(f"\n [4/4] {server_name} 错误处理: 调用未知工具 unknown_tool_xyz...") + resp = _send("tools/call", { + "name": "unknown_tool_xyz", + "arguments": {}, + }, expect_content=False) + assert "error" in resp, f"未知工具应返回 error" + assert resp["error"]["code"] == -32601 + print(f" 错误码: {resp['error']['code']}, 消息: {resp['error']['message']}") + results.append("PASS") + + except Exception as e: + print(f"\n {icon} 失败: {e}") + results.append(f"FAIL: {e}") + import traceback + traceback.print_exc() + + elapsed = time.time() - start + all_pass = all(r == "PASS" for r in results) + pass_count = results.count("PASS") + total_count = len(results) + status_msg = "全部通过" if all_pass else f"{pass_count}/{total_count} 通过" + print(f"\n {server_name} 耗时: {elapsed:.3f}s, 结果: {status_msg}") + + return all_pass + + +def test_time(): + """测试时间 MCP Server""" + from app.mcp.servers.time_server import handle_request + return simulate_client(handle_request, "时间", "(time)") + + +def test_weather(): + """测试天气 MCP Server(需要网络连接 Open-Meteo API)""" + from 
app.mcp.servers.weather_server import ( + handle_request, + _handle_weather_call, + ) + return simulate_client(handle_request, "天气", "(weather)", _handle_weather_call) + + +def main(): + print("=" * 70) + print(" LuomiNest MCP Server 协议兼容性验证") + print(" MCP 协议: 2024-11-05 | JSON-RPC 2.0 | stdio") + print("=" * 70) + + # 运行测试 + time_ok = test_time() + weather_ok = test_weather() + + # 汇总 + print() + print("=" * 70) + print(" 最终结果") + print("=" * 70) + print(f" 时间 MCP Server: {'PASS' if time_ok else 'FAIL'}") + print(f" 天气 MCP Server: {'PASS' if weather_ok else 'FAIL'}") + + if time_ok and weather_ok: + print() + print(" 全部 MCP Server 测试通过!") + print(" 这些 Server 现在可被任何兼容 MCP 的客户端加载使用。") + else: + print() + print(" 部分测试未通过,请检查失败项。") + + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/backend/app/runtime/__pycache__/__init__.cpython-313.pyc b/backend/app/runtime/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 347eb4b..0000000 Binary files a/backend/app/runtime/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/runtime/context.py b/backend/app/runtime/context.py deleted file mode 100644 index 2df40d0..0000000 --- a/backend/app/runtime/context.py +++ /dev/null @@ -1,57 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Optional -from datetime import datetime -from enum import Enum - - -class MessageRole(str, Enum): - USER = "user" - ASSISTANT = "assistant" - SYSTEM = "system" - TOOL = "tool" - - -@dataclass -class Message: - role: MessageRole - content: str - timestamp: datetime = field(default_factory=datetime.now) - metadata: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class UserContext: - user_id: str - username: str - session_id: str - device_id: Optional[str] = None - location: Optional[str] = None - permissions: list[str] = field(default_factory=list) - - -@dataclass -class AgentResult: - text: str = "" - actions: list[dict[str, Any]] = 
field(default_factory=list) - emotion: str = "neutral" - avatar_expression: str = "normal" - tool_calls: list[dict[str, Any]] = field(default_factory=list) - metadata: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class PipelineContext: - messages: list[Message] = field(default_factory=list) - user_context: Optional[UserContext] = None - agent_result: Optional[AgentResult] = None - raw_input: str = "" - intent: str = "" - matched_agent: Optional[str] = None - memory_recall: list[dict] = field(default_factory=list) - tool_results: list[dict] = field(default_factory=list) - emotion_analysis: dict[str, Any] = field(default_factory=dict) - should_stop: bool = False - extra: dict[str, Any] = field(default_factory=dict) - - def add_message(self, role: MessageRole, content: str, **kwargs) -> None: - self.messages.append(Message(role=role, content=content, **kwargs)) diff --git a/backend/app/runtime/context/__init__.py b/backend/app/runtime/context/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/context/pipeline_context.py b/backend/app/runtime/context/pipeline_context.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/event_bus/__init__.py b/backend/app/runtime/event_bus/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/event_bus/__pycache__/__init__.cpython-313.pyc b/backend/app/runtime/event_bus/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 2c09909..0000000 Binary files a/backend/app/runtime/event_bus/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/runtime/event_bus/core.py b/backend/app/runtime/event_bus/core.py deleted file mode 100644 index 54bdd1e..0000000 --- a/backend/app/runtime/event_bus/core.py +++ /dev/null @@ -1,61 +0,0 @@ -import asyncio -from enum import Enum -from dataclasses import dataclass, field -from typing import Any, Callable, Awaitable -from loguru import 
logger - - -class EventType(str, Enum): - USER_MESSAGE = "user.message" - AGENT_RESPONSE = "agent.response" - DEVICE_STATUS = "device.status" - DEVICE_COMMAND = "device.command" - PLUGIN_EVENT = "plugin.event" - SYSTEM_BROADCAST = "system.broadcast" - LOCATION_UPDATE = "location.update" - AUDIO_STREAM = "audio.stream" - - -@dataclass -class Event: - type: EventType - payload: dict[str, Any] - source: str = "" - target: str = "" - tenant_id: str = "default" - metadata: dict[str, Any] = field(default_factory=dict) - - -EventHandler = Callable[[Event], Awaitable[None]] - - -class EventBus: - def __init__(self): - self._handlers: dict[EventType, list[EventHandler]] = {} - self._queue: asyncio.Queue[Event] = asyncio.Queue() - self._running = False - - def on(self, event_type: EventType) -> Callable[[EventHandler], EventHandler]: - def decorator(handler: EventHandler) -> EventHandler: - if event_type not in self._handlers: - self._handlers[event_type] = [] - self._handlers[event_type].append(handler) - return handler - return decorator - - async def emit(self, event: Event) -> None: - await self._queue.put(event) - handlers = self._handlers.get(event.type, []) - for handler in handlers: - try: - await handler(event) - except Exception as e: - logger.error(f"Event handler error [{event.type}]: {e}") - - async def start(self) -> None: - self._running = True - logger.info("EventBus started") - - async def stop(self) -> None: - self._running = False - logger.info("EventBus stopped") diff --git a/backend/app/runtime/event_bus/subscriber.py b/backend/app/runtime/event_bus/subscriber.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/__init__.py b/backend/app/runtime/pipeline/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/__pycache__/__init__.cpython-313.pyc b/backend/app/runtime/pipeline/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index ff081d8..0000000 Binary 
files a/backend/app/runtime/pipeline/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/backend/app/runtime/pipeline/base.py b/backend/app/runtime/pipeline/base.py deleted file mode 100644 index eb9ad51..0000000 --- a/backend/app/runtime/pipeline/base.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABC, abstractmethod -from typing import AsyncGenerator -from app.runtime.context import PipelineContext -from loguru import logger - - -class PipelineStage(ABC): - @abstractmethod - async def process(self, ctx: PipelineContext) -> PipelineContext | None: - pass - - @property - @abstractmethod - def name(self) -> str: - pass - - @property - def order(self) -> int: - return 0 - - -class StopPropagation(Exception): - pass diff --git a/backend/app/runtime/pipeline/context.py b/backend/app/runtime/pipeline/context.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/engine.py b/backend/app/runtime/pipeline/engine.py deleted file mode 100644 index a8a26ce..0000000 --- a/backend/app/runtime/pipeline/engine.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional -from app.runtime.pipeline.base import PipelineStage -from app.runtime.pipeline.stages import ( - WakeWordStage, - AuthStage, - RateLimitStage, - SessionStage, - ContextBuildStage, - PreprocessStage, - AgentRouteStage, - LLMInferenceStage, - ToolExecuteStage, - MemoryExtractStage, - EmotionAnalysisStage, - ResponseDecorateStage, - MultiDispatchStage, - AuditLogStage, -) -from app.runtime.event_bus import EventBus, EventType -from app.runtime.context import PipelineContext -from loguru import logger - - -STAGE_REGISTRY = [ - WakeWordStage, - AuthStage, - RateLimitStage, - SessionStage, - ContextBuildStage, - PreprocessStage, - AgentRouteStage, - LLMInferenceStage, - ToolExecuteStage, - MemoryExtractStage, - EmotionAnalysisStage, - ResponseDecorateStage, - MultiDispatchStage, - AuditLogStage, -] - - -class Pipeline: - def __init__(self, event_bus: 
EventBus): - self.event_bus = event_bus - self.stages: list[PipelineStage] = sorted( - [s() for s in STAGE_REGISTRY], key=lambda x: x.order - ) - - async def execute(self, ctx: PipelineContext) -> PipelineContext: - for stage in self.stages: - try: - logger.debug(f"Pipeline stage: {stage.name}") - result = await stage.process(ctx) - if result is None or ctx.should_stop: - break - except Exception as e: - logger.error(f"Pipeline stage [{stage.name}] failed: {e}") - raise - - await self.event_bus.emit( - EventType.AGENT_RESPONSE, - {"result": ctx.agent_result, "session_id": ctx.user_context.session_id if ctx.user_context else ""}, - ) - return ctx diff --git a/backend/app/runtime/pipeline/stages/01_wake_word.py b/backend/app/runtime/pipeline/stages/01_wake_word.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/02_auth.py b/backend/app/runtime/pipeline/stages/02_auth.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/03_rate_limit.py b/backend/app/runtime/pipeline/stages/03_rate_limit.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/04_session.py b/backend/app/runtime/pipeline/stages/04_session.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/05_context_build.py b/backend/app/runtime/pipeline/stages/05_context_build.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/06_preprocess.py b/backend/app/runtime/pipeline/stages/06_preprocess.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/07_agent_route.py b/backend/app/runtime/pipeline/stages/07_agent_route.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/08_llm_inference.py b/backend/app/runtime/pipeline/stages/08_llm_inference.py deleted file mode 100644 index e69de29..0000000 
diff --git a/backend/app/runtime/pipeline/stages/09_tool_execute.py b/backend/app/runtime/pipeline/stages/09_tool_execute.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/10_memory_extract.py b/backend/app/runtime/pipeline/stages/10_memory_extract.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/11_emotion_analysis.py b/backend/app/runtime/pipeline/stages/11_emotion_analysis.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/12_response_decorate.py b/backend/app/runtime/pipeline/stages/12_response_decorate.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/13_multi_dispatch.py b/backend/app/runtime/pipeline/stages/13_multi_dispatch.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/14_audit_log.py b/backend/app/runtime/pipeline/stages/14_audit_log.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/runtime/pipeline/stages/__init__.py b/backend/app/runtime/pipeline/stages/__init__.py deleted file mode 100644 index dbfe44f..0000000 --- a/backend/app/runtime/pipeline/stages/__init__.py +++ /dev/null @@ -1,217 +0,0 @@ -from app.runtime.pipeline.base import PipelineStage, StopPropagation -from app.runtime.context import PipelineContext, MessageRole -from loguru import logger - - -class WakeWordStage(PipelineStage): - @property - def name(self) -> str: - return "wake_word" - - @property - def order(self) -> int: - return 1 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - wake_words = ["小洛", "luomi", "罗米"] - is_wake = any(ctx.raw_input.lower().startswith(ww.lower()) for ww in wake_words) - if is_wake: - ctx.raw_input = ctx.raw_input.strip() - for ww in wake_words: - if ctx.raw_input.lower().startswith(ww.lower()): - ctx.raw_input = ctx.raw_input[len(ww):].strip() - break - 
logger.debug(f"Wake word detected: {is_wake}") - return ctx - - -class AuthStage(PipelineStage): - @property - def name(self) -> str: - return "authentication" - - @property - def order(self) -> int: - return 2 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - from app.security.auth import verify_token - token = ctx.extra.get("token", "") - if token: - user = await verify_token(token) - if user and ctx.user_context: - ctx.user_context.user_id = user.id - ctx.user_context.permissions = user.permissions - return ctx - - -class RateLimitStage(PipelineStage): - @property - def name(self) -> str: - return "rate_limit" - - @property - def order(self) -> int: - return 3 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class SessionStage(PipelineStage): - @property - def name(self) -> str: - return "session" - - @property - def order(self) -> int: - return 4 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - ctx.add_message(MessageRole.USER, ctx.raw_input) - return ctx - - -class ContextBuildStage(PipelineStage): - @property - def name(self) -> str: - return "context_build" - - @property - def order(self) -> int: - return 5 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class PreprocessStage(PipelineStage): - @property - def name(self) -> str: - return "preprocess" - - @property - def order(self) -> int: - return 6 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class AgentRouteStage(PipelineStage): - @property - def name(self) -> str: - return "agent_route" - - @property - def order(self) -> int: - return 7 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - from app.domains.agent.router import IntentClassifier - classifier = IntentClassifier() - ctx.intent = await classifier.classify(ctx.raw_input) - ctx.matched_agent = classifier.select_agent(ctx.intent) - return ctx - - -class 
LLMInferenceStage(PipelineStage): - @property - def name(self) -> str: - return "llm_inference" - - @property - def order(self) -> int: - return 8 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - from app.runtime.provider.llm.adapter import LLMAdapter - adapter = LLMAdapter() - messages = [{"role": m.role.value, "content": m.content} for m in ctx.messages] - response_text = await adapter.chat(messages) - ctx.agent_result.__class__.__dict__.setdefault('text', response_text) - object.__setattr__(ctx.agent_result if ctx.agent_result else object(), 'text', response_text) - if ctx.agent_result is None: - from app.runtime.context import AgentResult - ctx.agent_result = AgentResult(text=response_text) - else: - ctx.agent_result.text = response_text - return ctx - - -class ToolExecuteStage(PipelineStage): - @property - def name(self) -> str: - return "tool_execute" - - @property - def order(self) -> int: - return 9 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class MemoryExtractStage(PipelineStage): - @property - def name(self) -> str: - return "memory_extract" - - @property - def order(self) -> int: - return 10 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class EmotionAnalysisStage(PipelineStage): - @property - def name(self) -> str: - return "emotion_analysis" - - @property - def order(self) -> int: - return 11 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class ResponseDecorateStage(PipelineStage): - @property - def name(self) -> str: - return "response_decorate" - - @property - def order(self) -> int: - return 12 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx - - -class MultiDispatchStage(PipelineStage): - @property - def name(self) -> str: - return "multi_dispatch" - - @property - def order(self) -> int: - return 13 - - async def process(self, ctx: PipelineContext) -> PipelineContext: 
- return ctx - - -class AuditLogStage(PipelineStage): - @property - def name(self) -> str: - return "audit_log" - - @property - def order(self) -> int: - return 14 - - async def process(self, ctx: PipelineContext) -> PipelineContext: - return ctx diff --git a/backend/app/runtime/plugin/skill/registry.py b/backend/app/runtime/plugin/skill/registry.py index 0aa1eb7..8330be5 100644 --- a/backend/app/runtime/plugin/skill/registry.py +++ b/backend/app/runtime/plugin/skill/registry.py @@ -95,6 +95,32 @@ def _register_builtins(cls): handler=cls._builtin_get_time, ) + # 天气工具:对接 app/runtime/plugin/skill/builtin/weather.py 的 get_weather + cls.register( + SkillDefinition( + name="get_weather", + description="获取指定城市的天气信息,包含温度、天气状况、风力、出行建议。当用户明确询问天气、气温、穿什么衣服、是否会下雨时使用。", + category="utility", + parameters={ + "city": { + "type": "string", + "description": "城市名称,如:北京、上海、广州", + "required": True, + }, + "date": { + "type": "string", + "description": "日期,如:今天、明天、后天、2026-05-06,可选,默认今天", + "required": False, + }, + }, + is_active=True, + is_builtin=True, + handler_name="get_weather", + tags=["weather", "天气", "utility"], + ), + handler=cls._builtin_get_weather, + ) + cls.register( SkillDefinition( name="transfer_to_agent", @@ -246,3 +272,55 @@ async def _builtin_transfer_agent(cls, **kwargs) -> 'SkillResult': metadata={"transfer": True, "target_agent_id": agent["id"]}, ) return SkillResult(success=False, error=f"Agent '{agent_name}' not found") + + @classmethod + async def _builtin_get_weather(cls, **kwargs) -> 'SkillResult': + """内置天气工具 handler —— 对接 weather_tool.py 的完整 API + 缓存 + 日期解析 + + 流程: + 1. 提取调用参数中的城市名和日期 + 2. 若城市名为空,返回引导用户补充的提示 + 3. 若日期为空,默认查询今天 + 4. 调用 weather_tool 的天气工具获取数据 + 5. 
全链路异常捕获,返回友好兜底,绝不暴露技术细节 + + 参数: + city: 城市名称(必填,LLM 从用户消息中提取) + date: 日期(可选,如"明天"、"5.1号"、"下周一",默认今天) + + 返回: + SkillResult,成功时 data 含 formatted 自然语言回复, + 失败时 error 为友好兜底话术。 + """ + from app.runtime.plugin.skill.base import SkillResult + + city_raw = kwargs.get("city", "") + date_raw = kwargs.get("date", "") + + # 城市名校验与清洗 + city = city_raw.strip() if city_raw else "" + # 去掉"市"后缀,如"北京市"→"北京" + if city.endswith("市") and len(city) > 1: + city = city[:-1] + + if not city: + return SkillResult( + success=False, + error="请告诉我你想查询哪个城市的天气,比如'北京天气怎么样'。" + ) + + # 日期清洗 + date_str = date_raw.strip() if date_raw else "" + + try: + # 调用 weather_tool 的核心接口(含日期解析 + 缓存 + 口语化回复) + from app.utils.weather_tool import _weather_tool + reply = await _weather_tool.get_reply(city=city, date_str=date_str) + return SkillResult(success=True, data={"formatted": reply}) + except Exception as e: + logger.warning(f"[SkillRegistry] _builtin_get_weather 异常: {e}") + return SkillResult( + success=False, + error=f"很抱歉,暂时无法为你获取「{city}」的实时天气数据。" + f"你可以打开手机自带的天气APP,或通过搜索引擎输入「{city} 今日天气」快速查询~" + ) diff --git a/backend/app/runtime/provider/base.py b/backend/app/runtime/provider/base.py index f8153b1..4c50276 100644 --- a/backend/app/runtime/provider/base.py +++ b/backend/app/runtime/provider/base.py @@ -32,6 +32,12 @@ async def embed(self, text: str) -> list[float]: async def list_models(self) -> list[dict]: pass + def supports_tool_calls(self, model: str = "") -> bool: + return False + + def supports_multimodal(self, model: str = "") -> bool: + return False + class STTProvider(ABC): provider_name: str = "base" diff --git a/backend/app/runtime/provider/llm/adapter.py b/backend/app/runtime/provider/llm/adapter.py index fd5428d..7e39cbc 100644 --- a/backend/app/runtime/provider/llm/adapter.py +++ b/backend/app/runtime/provider/llm/adapter.py @@ -59,7 +59,7 @@ def _create_provider_from_config(config: dict) -> OpenAICompatibleProvider: if not api_key: api_key = "ollama" if not default_model: - default_model 
= "qwen2.5:7b" + default_model = "qwen3-vl:8b" provider_name = "ollama" else: if not base_url: @@ -163,14 +163,22 @@ def get_provider(self, name: str | None = None) -> OpenAICompatibleProvider: def get_provider_config(self, name: str) -> dict | None: return self._provider_configs.get(name) + def supports_tool_calls(self, provider_name: str | None = None, model: str = "") -> bool: + try: + provider = self.get_provider(provider_name) + return provider.supports_tool_calls(model) + except ProviderError: + return False + async def chat( self, messages: list[dict], tools: list[dict] | None = None, stream: bool = False, provider_name: str | None = None, + return_raw: bool = False, **kwargs - ) -> str | AsyncIterator[str]: + ) -> str | dict | AsyncIterator[dict]: provider = self.get_provider(provider_name) actual_provider = provider_name or self.default_provider model = kwargs.get("model") or provider.default_model @@ -178,18 +186,19 @@ async def chat( start_time = time.time() try: - result = await provider.chat(messages, tools, stream, **kwargs) + result = await provider.chat(messages, tools, stream, return_raw=return_raw, **kwargs) elapsed = time.time() - start_time if isinstance(result, str): logger.success(f"[LLM] Chat response: provider={actual_provider}, elapsed={elapsed:.2f}s, len={len(result)}") - else: + elif hasattr(result, '__aiter__'): logger.info(f"[LLM] Chat stream started: provider={actual_provider}") + else: + reasoning_len = len(result.get("reasoning", "")) if isinstance(result, dict) else 0 + logger.success(f"[LLM] Chat response: provider={actual_provider}, elapsed={elapsed:.2f}s, reasoning={reasoning_len}") return result except Exception as e: elapsed = time.time() - start_time logger.error(f"[LLM] Chat failed: provider={actual_provider}, elapsed={elapsed:.2f}s, error={e}") - if provider_name: - raise ProviderError(f"Provider [{provider_name}] failed: {e}", provider=provider_name) return await self._fallback_chat(messages, tools, stream, **kwargs) async 
def _fallback_chat( @@ -197,8 +206,9 @@ async def _fallback_chat( messages: list[dict], tools: list[dict] | None = None, stream: bool = False, + return_raw: bool = False, **kwargs - ) -> str | AsyncIterator[str]: + ) -> str | dict | AsyncIterator[dict]: logger.warning("[LLM] Starting fallback chat...") provider_names = list(self.providers.keys()) if self.default_provider in self.providers: @@ -211,7 +221,7 @@ async def _fallback_chat( try: provider = self.providers[name] start_time = time.time() - result = await provider.chat(messages, tools, stream, **kwargs) + result = await provider.chat(messages, tools, stream, return_raw=return_raw, **kwargs) elapsed = time.time() - start_time logger.success(f"[LLM] Fallback success: provider={name}, elapsed={elapsed:.2f}s") return result @@ -229,7 +239,7 @@ async def chat_stream( tools: list[dict] | None = None, provider_name: str | None = None, **kwargs - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict]: provider = self.get_provider(provider_name) actual_provider = provider_name or self.default_provider model = kwargs.get("model") or provider.default_model diff --git a/backend/app/runtime/provider/llm/providers.py b/backend/app/runtime/provider/llm/providers.py index 2de5cc8..9cd102d 100644 --- a/backend/app/runtime/provider/llm/providers.py +++ b/backend/app/runtime/provider/llm/providers.py @@ -1,3 +1,4 @@ +import re from typing import AsyncIterator from app.runtime.provider.base import LLMProvider from loguru import logger @@ -5,6 +6,64 @@ import json +def _clean_reasoning_content(raw_reasoning: str) -> str: + """清理推理内容,去除模型名称、重复文本等噪声 + + Ollama 等本地模型可能在 reasoning 字段中返回模型标识符或元数据, + 此函数用于过滤这些非推理内容,确保只保留真正的思考过程。 + + 处理的场景: + - 纯模型名重复:qwen3-vl:8bqwen3-vl:8bqwen3-vl:8b... + - 模型名片段:vl:8bqwen3-vl:8b... 
+ - 行首/行尾的模型标识符 + """ + if not raw_reasoning: + return "" + + text = raw_reasoning.strip() + + # 场景1:检测连续重复的模型名称模式(最常见的问题) + # 匹配类似 "qwen3-vl:8b" 或 "llama-3.1-8b" 的模式重复 + model_name_pattern = r'[a-zA-Z0-9]+(?:-[a-zA-Z0-9.]+)*:[a-zA-Z0-9._-]+' + + # 检查是否整个文本主要由重复的模型名组成 + matches = re.findall(model_name_pattern, text) + if matches: + # 计算模型名占总文本的比例 + total_model_chars = sum(len(m) for m in matches) + ratio = total_model_chars / len(text) if text else 0 + + # 如果超过60%的字符都是模型名,认为是噪声 + if ratio > 0.6 and len(text) > 10: + logger.debug(f"[Provider] Filtered reasoning noise: model_name_ratio={ratio:.2f}, " + f"text_length={len(text)}") + return "" + + # 如果有多个相同的模型名重复出现(>=3次),也是噪声 + from collections import Counter + model_counts = Counter(matches) + most_common_model, count = model_counts.most_common(1)[0] if model_counts else ("", 0) + if count >= 3 and len(most_common_model) >= 5: + logger.debug(f"[Provider] Filtered repeated model name: '{most_common_model}' x{count}") + return "" + + # 场景2:移除行首/行尾的模型名(保留中间的有效内容) + # 行首模型名 + text = re.sub(r'^[a-zA-Z0-9_-]+:[a-zA-Z0-9._-]+\s*', '', text) + # 行尾模型名 + text = re.sub(r'\s*[a-zA-Z0-9_-]+:[a-zA-Z0-9._-]+$', '', text) + + # 场景3:移除孤立的模型名片段(如 "vl:8b" 前后没有其他有意义的内容) + # 如果清理后内容太短且看起来像片段,直接清空 + if len(text.strip()) < 8: + # 检查是否还包含模型名特征 + if re.search(r':[a-zA-Z0-9._-]', text): + logger.debug(f"[Provider] Filtered short fragment: length={len(text)}") + return "" + + return text.strip() + + PROVIDER_TEMPLATES = { "openai": { "id": "openai", @@ -138,7 +197,7 @@ "vendor": "ollama", "base_url": "http://localhost:11434/v1", "api_key": "ollama", - "default_model": "qwen2.5:7b", + "default_model": "qwen3-vl:8b", "description": "Local Ollama inference engine", }, "lmstudio": { @@ -171,6 +230,34 @@ } +_MODEL_CAPABILITIES = { + "gpt-4o": {"tool_calls": True, "multimodal": True}, + "gpt-4o-mini": {"tool_calls": True, "multimodal": True}, + "gpt-4-turbo": {"tool_calls": True, "multimodal": True}, + "gpt-4-": {"tool_calls": True, "multimodal": 
False}, + "o1": {"tool_calls": True, "multimodal": True}, + "o3-mini": {"tool_calls": True, "multimodal": False}, + "o3": {"tool_calls": True, "multimodal": False}, + "deepseek-chat": {"tool_calls": True, "multimodal": False}, + "deepseek-reasoner": {"tool_calls": False, "multimodal": False}, + "gemini": {"tool_calls": True, "multimodal": True}, + "mistral": {"tool_calls": True, "multimodal": False}, + "codestral": {"tool_calls": False, "multimodal": False}, + "llama-3.3-70b": {"tool_calls": True, "multimodal": False}, + "llama-3.1-": {"tool_calls": True, "multimodal": False}, + "grok": {"tool_calls": True, "multimodal": False}, + "moonshot-v1": {"tool_calls": True, "multimodal": False}, + "glm-4": {"tool_calls": True, "multimodal": True}, + "qwen-plus": {"tool_calls": True, "multimodal": False}, + "qwen-turbo": {"tool_calls": True, "multimodal": False}, + "qwen-max": {"tool_calls": True, "multimodal": False}, + "qwen2.5": {"tool_calls": True, "multimodal": False}, + "qwen3": {"tool_calls": True, "multimodal": False}, + "qwen2-vl": {"tool_calls": True, "multimodal": True}, + "qwen3-vl": {"tool_calls": True, "multimodal": True}, +} + + class OpenAICompatibleProvider(LLMProvider): provider_name = "openai_compatible" @@ -180,19 +267,55 @@ def __init__( base_url: str = "https://api.openai.com/v1", default_model: str = "gpt-4o-mini", provider_name: str = "openai_compatible", + force_enable_tool_calls: bool | None = None, ): self.api_key = api_key self.base_url = base_url.rstrip("/") self.default_model = default_model self.provider_name = provider_name + self.force_enable_tool_calls = force_enable_tool_calls + + def _lookup_capability(self, model: str, cap_key: str) -> bool | None: + if not model: + return None + model_lower = model.lower() + for key, caps in _MODEL_CAPABILITIES.items(): + if key in model_lower: + return caps.get(cap_key) + return None + + def supports_tool_calls(self, model: str = "") -> bool: + if self.force_enable_tool_calls is not None: + return 
self.force_enable_tool_calls + result = self._lookup_capability(model or self.default_model, "tool_calls") + if result is not None: + return result + if self.provider_name == "ollama": + return False + return True + + def supports_multimodal(self, model: str = "") -> bool: + result = self._lookup_capability(model or self.default_model, "multimodal") + if result is not None: + return result + return False async def chat( self, messages: list[dict], tools: list[dict] | None = None, stream: bool = False, + return_raw: bool = False, **kwargs - ) -> str | AsyncIterator[str]: + ) -> str | dict | AsyncIterator[dict]: + """调用大模型聊天接口 + + 参数: + messages: 对话消息列表 + tools: OpenAI Function Calling 格式工具定义列表 + stream: 是否使用流式响应 + return_raw: 是否返回完整 API 响应(含 tool_calls / reasoning),默认 False 仅返回文本 + """ if stream: return self.chat_stream(messages, tools, **kwargs) @@ -205,6 +328,18 @@ async def chat( ) resp.raise_for_status() data = resp.json() + if return_raw: + message = data.get("choices", [{}])[0].get("message", {}) + tool_calls = message.get("tool_calls", []) + raw_reasoning = message.get("reasoning", "") or message.get("reasoning_content", "") + # 清理推理内容 + reasoning = _clean_reasoning_content(raw_reasoning) + return { + "content": message.get("content", ""), + "reasoning": reasoning, + "tool_calls": tool_calls, + "role": message.get("role", "assistant"), + } return data["choices"][0]["message"]["content"] async def chat_stream( @@ -212,8 +347,9 @@ async def chat_stream( messages: list[dict], tools: list[dict] | None = None, **kwargs - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict]: payload = self._build_payload(messages, tools, stream=True, **kwargs) + collected_tool_calls: dict[int, dict] = {} async with httpx.AsyncClient(timeout=180.0) as client: async with client.stream( "POST", @@ -230,13 +366,47 @@ async def chat_stream( break try: data = json.loads(data_str) - delta = data.get("choices", [{}])[0].get("delta", {}) + choice = data.get("choices", [{}])[0] + delta = 
choice.get("delta", {}) content = delta.get("content", "") - if content: - yield content + raw_reasoning = delta.get("reasoning", "") or delta.get("reasoning_content", "") + + reasoning = _clean_reasoning_content(raw_reasoning) + + tool_calls_delta = delta.get("tool_calls") + if tool_calls_delta: + for tc in tool_calls_delta: + idx = tc.get("index", 0) + if idx not in collected_tool_calls: + collected_tool_calls[idx] = {"id": "", "name": "", "arguments": ""} + if tc.get("id"): + collected_tool_calls[idx]["id"] = tc["id"] + fn = tc.get("function", {}) + if fn.get("name"): + collected_tool_calls[idx]["name"] = fn["name"] + if fn.get("arguments"): + collected_tool_calls[idx]["arguments"] += fn["arguments"] + + result = {"content": content, "reasoning": reasoning} + if content or reasoning: + yield result except json.JSONDecodeError: continue + if collected_tool_calls: + merged = [] + for idx in sorted(collected_tool_calls.keys()): + entry = collected_tool_calls[idx] + merged.append({ + "id": entry["id"] or f"call_{idx}", + "type": "function", + "function": { + "name": entry["name"], + "arguments": entry["arguments"], + } + }) + yield {"content": "", "reasoning": "", "tool_calls_complete": merged} + async def embed(self, text: str) -> list[float]: async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.post( diff --git a/backend/app/schemas/chat.py b/backend/app/schemas/chat.py index 0c4a0ad..3ae38fa 100644 --- a/backend/app/schemas/chat.py +++ b/backend/app/schemas/chat.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Any +from typing import Any, Literal class ChatMessageCreate(BaseModel): @@ -16,6 +16,15 @@ class ChatRequest(BaseModel): top_p: float | None = None stream: bool = False agent_id: str | None = None + timestamp: float | None = None + file_content: str | None = Field(default=None, max_length=100_000_000) + file_name: str | None = Field(default=None, max_length=255) + file_type: Literal[ + "text", "image", + 
"text/plain", "image/png", "image/jpeg", "image/gif", "image/webp", + "application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ] | None = None + search_results: str | None = Field(default=None, max_length=100_000) class ChatResponse(BaseModel): @@ -29,6 +38,7 @@ class ChatResponse(BaseModel): class ChatStreamChunk(BaseModel): id: str content: str + reasoning_content: str = "" model: str provider: str done: bool = False diff --git a/backend/app/schemas/common.py b/backend/app/schemas/common.py deleted file mode 100644 index 7a23971..0000000 --- a/backend/app/schemas/common.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel -from typing import Any - - -class ApiResponse(BaseModel): - error: dict[str, str] | None = None - data: Any = None - - -class PaginatedResponse(BaseModel): - items: list[Any] - total: int - page: int = 1 - page_size: int = 20 diff --git a/backend/app/schemas/device.py b/backend/app/schemas/device.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/schemas/plugin.py b/backend/app/schemas/plugin.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/utils/intent_gateway.py b/backend/app/utils/intent_gateway.py new file mode 100644 index 0000000..1de2e5a --- /dev/null +++ b/backend/app/utils/intent_gateway.py @@ -0,0 +1,552 @@ +""" +前置意图网关 - 三级规则分类模块 + +功能: + 对用户输入消息进行轻量分类,将请求分为三类: + - LOCAL_TOOL:本地可处理的请求(时间/日期/星期/计算) + - TOOL_CALL:需要调用外部工具的请求(天气/搜索/旅游/行程/实时数据) + - GENERAL_CHAT:通用对话,其余所有请求的默认分类 + +三级分类流程(纯规则,零延迟,不用大模型): + 第一层:关键词粗匹配 —— 正则快速抓出所有含时间/日期关键词的候选 + 第二层:轻量级规则二次过滤 —— 否定词、句式结构、长度限制三重过滤 + 第三层:边缘情况兜底 —— 拿不准的直接走外接 API,不影响体验 + +增强特性: + - 集成八维度搜索意图评分器,识别隐式搜索需求 + - 支持对话历史上下文感知 + - 扩展工具调用关键词覆盖更多场景 + - 新增实时数据/知识边界/实体识别等隐式搜索触发 + +设计原则: + 1. 纯正则 + 规则树,零 IO、零网络、零大模型调用 + 2. 分类优先级:本地工具 > 工具调用 > 通用对话 + 3. 输入清洗:去除空格、中英文问号后匹配,避免格式干扰 + 4. 
边界安全:空消息、纯标点消息默认返回 GENERAL_CHAT + 5. 保守策略:宁可漏判真查询让大模型兜底,也不能误判假查询 +""" + +import re +from enum import Enum + + +class RequestType(Enum): + """请求类型枚举,对应三种分流目标""" + LOCAL_TOOL = "local_tool" # 本地工具:时间/日期/星期/计算 + TOOL_CALL = "tool_call" # 工具调用:天气/搜索/旅游/行程/实时数据 + GENERAL_CHAT = "general_chat" # 通用对话:其余所有请求 + + +class IntentGateway: + """意图网关 —— 三级规则分类引擎 + + 用法: + gateway = IntentGateway() + result = gateway.classify("现在几点了") # RequestType.LOCAL_TOOL + result = gateway.classify("几点开会还没定") # RequestType.GENERAL_CHAT + """ + + def __init__(self): + # ================================================================ + # 第一层:关键词正则(粗匹配,先抓所有候选) + # ================================================================ + # 时间/日期关键词正则 + self.time_date_pattern = re.compile( + r"几点|几时|几号|几月几|日期|周几|星期几|礼拜几|几月|哪一天|" + r"什么时间|什么日期|当前时间|现在时间|看时间|报时|几月份|啥时候|" + r"农历|阴历|初一|十五|" + r"什么日子|什么节日|什么节|法定节假日|节假日|节日|过什么节|放不放假|" + r"今天几|明天几|后天几|昨天几|" + r"下周|上周|下个礼拜|上个礼拜|" + r"\d+天后|\d+天前|" + r"\d+[个]*(?:小时|分钟|天|钟头)[后前]|" + r"[一二三四五六七八九十]+[个]*(?:小时|分钟|天|钟头)[后前]|" + r"过\d+(?:小时|分钟|天)|" + r"时区|GMT|UTC|时差|" + r"[的地]时间(?!点|分|钟|段|候|长|差)", + ) + + # 计算类正则:数字运算符 或 计算意图词 + self.calc_pattern = re.compile( + r"\d+\s*[\+\-\*×xX÷/]\s*\d+" + r"|" + r"\d+\s*(?:加|減|减|乘|除|乘以|除以)\s*\d+" + r"|" + r"(计算|算一下|帮我算|等于多少|等于几|得多少|得几|是多少|答案是)" + ) + + # 天气关键词正则(粗匹配,所有含天气关键词的候选) + self.weather_pattern = re.compile( + r"天气|气温|温度|降水|降雨|下雪|下雨|湿度|风力|风向|" + r"空气质量|PM2\.5|pm2\.5|PM2|雾霾|预报|冷不冷|" + r"紫外线|带伞|穿衣|防晒|晴|多云|阴天|刮风|台风", + re.IGNORECASE, + ) + + # ================================================================ + # 第二层:轻量级规则配置(核心防误判) + # ================================================================ + + # 明确查询词 —— 命中任一即确认真查询(时间用) + self.query_words = { + "请问", "帮我查", "告诉我", "问一下", "查一下", "现在是", "今天是", + "帮忙看下", "麻烦告诉", "我想知道", "帮我看看", "问下", "请教", + "请告诉我", "帮忙查", "帮我问", "查查", + } + + # 否定词过滤列表 —— 命中任一即确认为假查询(时间/日期用) + self.negation_words = { + "不知道", "不确定", "没定", "还没", "忘了", "不记得", "没想好", + "不告诉你", "记不清", "记不得", "搞不清", "搞不懂", 
"没注意", + "不清楚", "不晓得", "没记住", "想不起", "说不上来", "忘记了", + "无从知晓", "搞不明白", "弄不明白", "弄不清", "说不好", + } + + # 假查询动词("几点+动词" 结构,无查询词 → 假查询) + self.fake_verbs = { + "出门", "开会", "吃饭", "睡觉", "上班", "下班", "约会", + "见面", "出发", "到达", "集合", "开始", "结束", "面试", + "上课", "下课", "放学", "起飞", "降落", "登机", "登车", + "开门", "关门", "打烊", "签到", "签退", "训练", "彩排", + "直播", "答辩", "考试", "复试", "笔试", "交班", "接班", + } + + # ----- 天气专属规则配置 ----- + + # 天气明确查询词白名单 + self.weather_query_whitelist: set[str] = { + "请问", "帮我查", "查一下", "告诉我", "问一下", "怎么样", + "如何", "多少", "帮我看", "麻烦", "请帮", "我想知道", + } + + # 天气非查询场景黑名单 + self.weather_negation_blacklist: set[str] = { + "不好", "不错", "太热", "太冷", "下雨了", "下雪了", + "上次", "之前", "的时候", "受不了", "烦", "讨厌", + } + + # 天气陈述动词模式 + self.weather_statement_verbs: set[str] = { + "不好", "不错", "热了", "变了", "冷了", "暖和", + "太差", "影响", "耽误", "坏了", + } + + # ----- 事件日期检测关键词 ----- + # 当消息同时包含时间关键词和这些事件关键词时, + # 本地时间工具无法回答,应走 TOOL_CALL(搜索工具) + self.event_date_keywords: set[str] = { + "软考", "考研", "高考", "中考", "国考", "省考", + "考公", "公务员", "事业编", "选调", "教资", "法考", + "注会", "一建", "二建", "复试", "笔试", "面试", + "报名", "准考证", "成绩", "录取", "分数线", + "世界杯", "奥运会", "亚运会", "世博会", "欧冠", + "NBA", "欧洲杯", "亚洲杯", "全运会", + "上映", "开售", "预售", "发售", "发布", + "开学", "放假", "开学季", "毕业", + "春运", "假期", "调休", + "发布会", "发布会", "直播", + } + + # ================================================================ + # 工具调用关键词(搜索+旅游+实时数据,天气已被 classify() 接管) + # ================================================================ + self._tool_keywords_search = _TOOL_KEYWORDS_SEARCH + self._tool_keywords_travel = _TOOL_KEYWORDS_TRAVEL + self._tool_keywords_realtime = _TOOL_KEYWORDS_REALTIME + self._tool_keywords_knowledge_boundary = _TOOL_KEYWORDS_KNOWLEDGE_BOUNDARY + self._tool_keywords_comparison = _TOOL_KEYWORDS_COMPARISON + self._tool_keywords_fact_specific = _TOOL_KEYWORDS_FACT_SPECIFIC + + def classify( + self, + user_message: str, + conversation_history: list[dict] | None = None, + ) -> RequestType: + """三层分类,返回 RequestType 枚举值 + + 支持返回 
LOCAL_TOOL(本地工具直接处理)、TOOL_CALL(工具调用)、 + GENERAL_CHAT(通用对话)三种类型。 + + 纯规则实现,零延迟,不调大模型。 + + 参数: + user_message: 用户原始消息文本 + conversation_history: 对话历史(可选,用于上下文感知) + + 返回: + RequestType 枚举值 + """ + # 0. 边界与清洗 + if user_message is None: + return RequestType.GENERAL_CHAT + + original_msg = user_message.strip() + + # 清洗:去问号、去空格,用于关键词匹配 + clean_msg = original_msg.replace("?", "").replace("?", "").replace(" ", "").replace(" ", "") + + # 空消息或纯标点 → 通用对话 + if not clean_msg: + return RequestType.GENERAL_CHAT + if not re.sub(r"[\s\.,!!。,、;;::·~`@#$%^&*()()\[\]【】{}/\\|'\"<>《》\-_=+]+", "", clean_msg): + return RequestType.GENERAL_CHAT + + # ================================================================ + # 第一层:关键词粗匹配 + # ================================================================ + has_weather_keyword = bool(self.weather_pattern.search(clean_msg)) + has_time_keyword = bool(self.time_date_pattern.search(clean_msg)) + has_calc_keyword = bool(self.calc_pattern.search(clean_msg)) + + # ---- 天气检测(最优先!天气有自己的规则体系)---- + if has_weather_keyword: + return self._classify_weather(original_msg, clean_msg) + + # 纯计算请求(含运算符但无时间词)→ 本地工具,不走防误判 + if has_calc_keyword and not has_time_keyword: + return RequestType.LOCAL_TOOL + + # 无任何匹配 → 通用对话 + if not has_time_keyword and not has_calc_keyword: + return RequestType.GENERAL_CHAT + + # ================================================================ + # 第二层:轻量级规则二次过滤(时间/日期,核心!过滤 99% 误判) + # ================================================================ + + # 规则一:否定词 → 假查询 + if any(word in original_msg for word in self.negation_words): + return RequestType.GENERAL_CHAT + + # 规则二:明确查询词 → 真查询 + if any(word in original_msg for word in self.query_words): + return RequestType.LOCAL_TOOL + + # 规则二点五:事件日期检测 → TOOL_CALL + # 当消息同时包含时间关键词和特定事件/考试关键词时, + # 本地时间工具无法回答,应走搜索工具 + # 如 "河北软考几号"、"考研什么时候"、"国考几号" + if any(kw in clean_msg for kw in self.event_date_keywords): + return RequestType.TOOL_CALL + + # 规则三:"几点+动词" / "几号+动词" 结构 → 假查询 + for verb in 
self.fake_verbs: + test_str = clean_msg.lower() + if f"几点{verb}" in test_str or f"几号{verb}" in test_str: + return RequestType.GENERAL_CHAT + + # 规则四:短句(≤10字)且时间词在首/尾 → 真查询 + if len(original_msg) <= 10: + msg_start = original_msg[:5] + msg_end = original_msg[-5:] + if self.time_date_pattern.search(msg_start) or self.time_date_pattern.search(msg_end): + return RequestType.LOCAL_TOOL + + # 规则五:长句(>20字)且非明确查询 → 假查询 + if len(original_msg) > 20: + return RequestType.GENERAL_CHAT + + # ================================================================ + # 第三层:边缘情况走 API 兜底 + # ================================================================ + return RequestType.GENERAL_CHAT + + # ------------------------------------------------------------------ + # 天气分类子方法 + # ------------------------------------------------------------------ + + def _classify_weather(self, original_msg: str, clean_msg: str) -> RequestType: + """天气专用分类 —— 四层规则精准区分真查询 vs 假查询 + + 参数: + original_msg: 用户原始消息 + clean_msg: 清洗后的消息 + + 返回: + TOOL_CALL:真实天气查询 + GENERAL_CHAT:非天气查询 + """ + # R_W1: 天气明确查询词 → 真查询 + if any(word in original_msg for word in self.weather_query_whitelist): + return RequestType.TOOL_CALL + + # R_W2: 天气否定/陈述词且无查询词 → 假查询 + if any(word in original_msg for word in self.weather_negation_blacklist): + return RequestType.GENERAL_CHAT + + # R_W3: "天气+陈述动词"模式 → 假查询 + for verb in self.weather_statement_verbs: + if f"天气{verb}" in clean_msg: + return RequestType.GENERAL_CHAT + + # R_W4: 短句(≤10字)含天气词 → 真查询 + if len(original_msg) <= 10: + return RequestType.TOOL_CALL + + # R_W5: 长句(>25字)且无明确查询词 → 假查询 + if len(original_msg) > 25: + return RequestType.GENERAL_CHAT + + # 兜底:含天气关键词且未被过滤 → 真查询 + return RequestType.TOOL_CALL + + +# ============================================================================= +# 工具调用关键词集合 —— 覆盖搜索/旅游/实时数据/知识边界/比较/事实特异性 +# ============================================================================= + +_TOOL_KEYWORDS_SEARCH = { + "搜索", "查找", "搜一下", "查一下", "帮我搜", "帮我查", + "百度", 
"谷歌", "google", "百度一下", "搜一搜", + "帮我找", "帮我看看", "查资料", "检索", "搜寻", +} + +_TOOL_KEYWORDS_TRAVEL = { + "旅游", "旅行", "度假", "景点", "攻略", "游记", + "行程", "路线", "导航", "怎么去", "怎么走", "如何去", + "酒店", "民宿", "机票", "火车票", "订票", "订酒店", + "规划", "安排行程", "出行计划", "自驾", "跟团", + "周边游", "一日游", "几日游", "自由行", "签证", +} + +_TOOL_KEYWORDS_COUNTDOWN = { + "距离", "还有几天", "还剩几天", "剩下几天", "还有多久", + "是哪天", "是几号", "考试时间", "什么时候考试", "什么时候报名", +} + +_TOOL_KEYWORDS_REALTIME = { + "股价", "股票", "行情", "汇率", "油价", "金价", "房价", + "限行", "限号", "停水", "停电", "快递", "物流", + "招聘", "求职", "签证", "出入境", "入境政策", + "油价", "汽油价", "黄金价", "二手房", "均价", +} + +_TOOL_KEYWORDS_KNOWLEDGE_BOUNDARY = { + "最新", "当前", "目前", "刚刚", "刚才", + "2025年", "2026年", "2027年", "2028年", "2029年", +} + +_TOOL_KEYWORDS_COMPARISON = { + "哪个好", "怎么选", "对比", "比较", "区别", "差异", + "性价比", "划算", "值得买", "买哪个", "选哪个", + "排行", "排名", "榜单", "口碑", "评测", "测评", +} + +_TOOL_KEYWORDS_FACT_SPECIFIC = { + "分数线", "录取线", "报名费", "学费", "票价", "门票", + "营业时间", "开放时间", "官网", "下载地址", + "名额", "招生人数", "招聘人数", +} + + +def _is_tool_call_request(cleaned: str) -> bool: + """检查消息是否命中任一工具调用关键词集合""" + keyword_sets = [ + _TOOL_KEYWORDS_SEARCH, + _TOOL_KEYWORDS_TRAVEL, + _TOOL_KEYWORDS_COUNTDOWN, + _TOOL_KEYWORDS_REALTIME, + _TOOL_KEYWORDS_KNOWLEDGE_BOUNDARY, + _TOOL_KEYWORDS_COMPARISON, + _TOOL_KEYWORDS_FACT_SPECIFIC, + ] + return any(keyword in cleaned for keyword_set in keyword_sets for keyword in keyword_set) + + +# ============================================================================= +# 全局单例与对外接口 +# ============================================================================= + +_gateway = IntentGateway() + + +def classify_request( + user_message: str, + conversation_history: list[dict] | None = None, +) -> RequestType: + """核心分类函数 —— 对用户消息进行毫秒级意图分类 + + 四级分类流程: + 1. IntentGateway 判断 LOCAL_TOOL vs GENERAL_CHAT + 2. 关键词集合判断 TOOL_CALL(扩展覆盖实时数据/知识边界/比较/事实特异性) + 3. 搜索意图评分器判断隐式搜索需求(八维度评分) + 4. 
兜底 GENERAL_CHAT + + 参数: + user_message: 用户输入的原始消息文本 + conversation_history: 对话历史(可选,用于上下文感知) + + 返回: + RequestType 枚举值,指示该请求的类型 + """ + # 第一步:本地工具 vs 通用对话(三级规则引擎) + result = _gateway.classify(user_message, conversation_history) + + # 第二步:如果三级引擎判定为 GENERAL_CHAT,再检查是否为工具调用 + if result == RequestType.GENERAL_CHAT: + clean_msg = ( + user_message.replace("?", "").replace("?", "") + .replace(" ", "").replace(" ", "") + ) + # 若消息中含有时间/日期关键词(被网关第二层规则过滤的), + # 说明整体语境是闲聊陈述而非信息查询,不应当触发工具调用。 + if _gateway.time_date_pattern.search(clean_msg): + return RequestType.GENERAL_CHAT + if _is_tool_call_request(clean_msg): + return RequestType.TOOL_CALL + + # 第三步:搜索意图评分器 —— 识别隐式搜索需求 + # 八维度评分:问题模式/实体时效/话题类别/否定信号/ + # 知识边界/实体识别/比较评价/事实特异性 + from app.utils.search_intent import needs_search + if needs_search(user_message, conversation_history): + return RequestType.TOOL_CALL + + return result + + +def is_weather_query(user_message: str) -> bool: + """辅助判断函数 —— 检查用户消息是否为真实天气查询 + + 参数: + user_message: 用户输入的原始消息文本 + + 返回: + True:真实天气查询 + False:非天气查询 + """ + if not user_message or not user_message.strip(): + return False + try: + result = _gateway.classify(user_message) + return result == RequestType.TOOL_CALL + except Exception: + return False + + +# ============================================================================= +# 直接运行验证(python -m app.utils.intent_gateway) +# ============================================================================= +if __name__ == "__main__": + test_cases = [ + # ===== 真时间查询 → LOCAL_TOOL ===== + ("现在几点了", RequestType.LOCAL_TOOL), + ("今天几号", RequestType.LOCAL_TOOL), + ("请问现在几点", RequestType.LOCAL_TOOL), + ("帮我查一下今天周几", RequestType.LOCAL_TOOL), + ("今天是星期几", RequestType.LOCAL_TOOL), + ("几点", RequestType.LOCAL_TOOL), + ("现在时间", RequestType.LOCAL_TOOL), + + # ===== 假时间查询 → GENERAL_CHAT ===== + ("我不知道今天几点出门", RequestType.GENERAL_CHAT), + ("几点开会还没定", RequestType.GENERAL_CHAT), + ("忘了今天是几号了", RequestType.GENERAL_CHAT), + ("几点吃饭", RequestType.GENERAL_CHAT), + 
("明天几点集合", RequestType.GENERAL_CHAT), + ("不确定几点下班", RequestType.GENERAL_CHAT), + ("几点面试来着记不清了", RequestType.GENERAL_CHAT), + ("出门的时间几点了还不知道呢", RequestType.GENERAL_CHAT), + + # ===== 工具调用 → TOOL_CALL ===== + ("今天天气怎么样", RequestType.TOOL_CALL), + ("帮我搜索一下资料", RequestType.TOOL_CALL), + ("推荐一个旅游景点", RequestType.TOOL_CALL), + + # ===== 真天气查询 → TOOL_CALL ===== + ("北京天气怎么样", RequestType.TOOL_CALL), + ("明天会下雨吗", RequestType.TOOL_CALL), + ("请问今天气温多少", RequestType.TOOL_CALL), + ("帮我查一下上海明天的天气", RequestType.TOOL_CALL), + ("告诉我市区空气质量", RequestType.TOOL_CALL), + ("明天温度", RequestType.TOOL_CALL), + ("后天降水概率如何", RequestType.TOOL_CALL), + + # ===== 假天气查询 → GENERAL_CHAT ===== + ("今天天气不好不想出门", RequestType.GENERAL_CHAT), + ("天气不错适合出去玩", RequestType.GENERAL_CHAT), + ("今天太热了受不了", RequestType.GENERAL_CHAT), + ("上次下雨的时候我忘带伞了", RequestType.GENERAL_CHAT), + ("天气热了记得多喝水", RequestType.GENERAL_CHAT), + ("今天天气变化太大了烦死了", RequestType.GENERAL_CHAT), + ("之前下雪的时候拍的", RequestType.GENERAL_CHAT), + + # ===== 隐式搜索 → TOOL_CALL(八维度评分器触发)===== + ("2026年世界杯在哪举办", RequestType.TOOL_CALL), + ("iPhone 18什么时候出", RequestType.TOOL_CALL), + ("特斯拉股价多少", RequestType.TOOL_CALL), + ("最近有什么好看的电影", RequestType.TOOL_CALL), + ("河北软考几号", RequestType.TOOL_CALL), + ("今年考研什么时候报名", RequestType.TOOL_CALL), + ("NBA总决赛比分", RequestType.TOOL_CALL), + ("北京到上海的高铁时刻表", RequestType.TOOL_CALL), + + # ===== 知识边界 → TOOL_CALL ===== + ("2025年有什么新政策", RequestType.TOOL_CALL), + ("目前GPT-5出了吗", RequestType.TOOL_CALL), + ("最新版本的ChatGPT是什么", RequestType.TOOL_CALL), + + # ===== 实体识别 → TOOL_CALL ===== + ("GPT-5什么时候发布", RequestType.TOOL_CALL), + ("DeepSeek最新模型是什么", RequestType.TOOL_CALL), + ("Windows 12什么时候出", RequestType.TOOL_CALL), + + # ===== 比较评价 → TOOL_CALL ===== + ("iPhone 16和华为Mate70哪个好", RequestType.TOOL_CALL), + ("比亚迪和特斯拉怎么选", RequestType.TOOL_CALL), + + # ===== 事实特异性 → TOOL_CALL ===== + ("清华录取分数线多少", RequestType.TOOL_CALL), + ("北京故宫门票多少钱", RequestType.TOOL_CALL), + ("GPT-4官网下载地址", RequestType.TOOL_CALL), + + # ===== 实时数据 → TOOL_CALL ===== + 
("今天油价多少", RequestType.TOOL_CALL), + ("黄金价格多少一克", RequestType.TOOL_CALL), + ("北京今天限行尾号", RequestType.TOOL_CALL), + ("美元汇率多少", RequestType.TOOL_CALL), + + # ===== 否定覆盖 → TOOL_CALL ===== + ("什么是GPT-5", RequestType.TOOL_CALL), + ("什么是2025年新规", RequestType.TOOL_CALL), + + # ===== 通用对话 → GENERAL_CHAT ===== + ("今天天气不错,几点吃饭?", RequestType.GENERAL_CHAT), + ("现在几点?不对,等一下", RequestType.GENERAL_CHAT), + ("给我写一段朋友圈文案", RequestType.GENERAL_CHAT), + ("你好,请介绍一下你自己", RequestType.GENERAL_CHAT), + ("", RequestType.GENERAL_CHAT), + ("???", RequestType.GENERAL_CHAT), + ("3+5等于多少", RequestType.LOCAL_TOOL), + ("Python怎么写快速排序", RequestType.GENERAL_CHAT), + ("什么是量子力学", RequestType.GENERAL_CHAT), + ("翻译一下这段话", RequestType.GENERAL_CHAT), + ("解释一下相对论", RequestType.GENERAL_CHAT), + ] + + print("=" * 80) + print(" IntentGateway 三级规则分类 测试结果(含八维度搜索意图)") + print("=" * 80) + print() + + passed = 0 + failed = 0 + + for msg, expected in test_cases: + display = msg if msg else "(空消息)" + result = classify_request(display) + if result == expected: + status = "PASS" + passed += 1 + else: + status = "FAIL" + failed += 1 + print(f" [{status}] {display:40} | 预期: {expected.value:14} | 实际: {result.value}") + + print() + print(f" 通过: {passed} 失败: {failed} 总计: {len(test_cases)}") + + if failed == 0: + print("\n 全部测试通过!") + else: + print(f"\n 有 {failed} 个测试未通过,需要检查规则配置") diff --git a/backend/app/utils/local_handler.py b/backend/app/utils/local_handler.py new file mode 100644 index 0000000..027dcc1 --- /dev/null +++ b/backend/app/utils/local_handler.py @@ -0,0 +1,220 @@ +""" +本地请求处理器 - 本地工具请求的统一分发入口 + +功能: + 将用户消息分流到对应的本地处理工具,当前已接入: + - 时间工具(time_tool):毫秒级时间/日期/星期查询 + - 天气工具(weather_tool):毫秒级天气查询 + +职责: + 1. 调用 intent_gateway 进行请求分类 + 2. 命中 LOCAL_TOOL 后,调用对应工具生成回复 + 3. 命中天气请求时,提取城市后调用天气工具 + 4. 未命中则返回 None,不影响调用方继续走原有的对话流程 + 5. 全链路异常兜底,绝对不会中断主流程 + +设计原则: + 1. 纯分发逻辑,不包含任何业务计算 + 2. 返回 None 表示"此请求不属于本地处理范围",调用方可继续走大模型 + 3. 返回有效字符串表示"已本地处理完毕",调用方可直接使用回复 + 4. 
极端异常也返回兜底话术,确保对话不中断 +""" + +import re + +from loguru import logger + +from app.utils.intent_gateway import classify_request, RequestType, is_weather_query +from app.utils.time_tool import TimeTool +from app.utils.weather_tool import _weather_tool + +# 模块级时间工具单例 —— 保持多轮对话状态跨请求持久化 +_time_tool_instance = TimeTool(timezone="Asia/Shanghai") + + +# ============================================================================= +# 城市名提取 —— 从用户消息中提取城市名称 +# ============================================================================= + +# 中国主要城市名正则 +_CITY_PATTERN = re.compile( + r"(北京|上海|广州|深圳|杭州|成都|武汉|西安|南京|重庆|天津|" + r"苏州|长沙|郑州|济南|青岛|大连|厦门|福州|昆明|贵阳|南宁|" + r"海口|三亚|哈尔滨|长春|沈阳|乌鲁木齐|拉萨|兰州|银川|西宁|" + r"呼和浩特|太原|石家庄|合肥|南昌|东莞|佛山|无锡|宁波|温州|" + r"徐州|珠海|惠州|中山|烟台|威海)" +) + +# 城市别名正则 —— 单字别名仅在独立出现时匹配(前后无中文字符) +_CITY_ALIAS_PATTERN = re.compile( + r"(? str | None: + """从用户消息中提取城市名称 + + 在消息中搜索匹配的城市名,返回第一个匹配的城市(完整名称)。 + 优先匹配完整城市名,再匹配独立出现的单字别名。 + + 参数: + user_message: 用户输入的原始消息文本 + + 返回: + 城市完整名称,未找到返回 None + """ + matched = _CITY_PATTERN.findall(user_message) + if matched: + city = matched[0] + return _CITY_ALIAS.get(city, city) + + alias_matched = _CITY_ALIAS_PATTERN.findall(user_message) + if alias_matched: + city = alias_matched[0] + return _CITY_ALIAS.get(city, city) + + return None + + +def _extract_date_from_message(user_message: str, city: str) -> str: + """从用户消息中提取日期部分,传给天气工具的 parse_query_date 解析 + + 处理流程: + 1. 去掉消息中的城市名称 + 2. 去掉常见的天气查询词(天气、气温、多少度等) + 3. 
剩下的部分即为日期候选文本 + + 参数: + user_message: 用户输入的原始消息文本 + city: 已提取的城市名 + + 返回: + 日期候选文本,如"明天"、"5.1号"、"下周一",无日期时返回空字符串 + """ + clean = user_message + # 去掉城市名 + clean = clean.replace(city, "") + # 去掉常见查询后缀 + for phrase in [ + "天气怎么样", "天气如何", "天气怎样", "天气", + "气温", "温度", "多少度", "几度", "冷不冷", "热不热", + "怎么样", "如何", "怎样", "预报", "天气预报", + "穿衣", "带伞", "防晒", "的", "吗", "吧", "呢", "啊", "哦", + ]: + clean = clean.replace(phrase, "") + + return clean.strip() + + +# ============================================================================= +# 天气请求处理 +# ============================================================================= + +async def handle_weather_request(user_message: str) -> str | None: + """处理天气查询请求 —— 提取城市和日期,调用天气工具 + + 流程: + 1. 从消息中提取城市名称 + 2. 从消息中提取日期(如"明天"、"5.1号"、"下周一") + 3. 若提取城市失败 → 返回 None,让调用方继续走原有流程 + 4. 直接 await 天气工具的异步接口获取回复 + + 参数: + user_message: 用户输入的原始消息文本 + + 返回: + - 有效字符串:天气回复 + - None:无法提取城市,应由调用方继续处理 + + 用法: + reply = await handle_weather_request("北京明天天气怎么样") + if reply: + return reply # 已本地处理 + """ + try: + city = _extract_city(user_message) + if city is None: + # 无城市名 → 返回 None,让调用方走工具调用循环 + # (大模型可以从上下文推断城市) + return None + + # 提取日期:去掉城市名称部分后传入 parse_query_date + date_str = _extract_date_from_message(user_message, city) + + # 在异步上下文中直接 await,避免同步封装的事件循环冲突 + reply = await _weather_tool.get_reply(city, date_str) + return reply + + except Exception as e: + logger.warning(f"[LocalHandler] 处理天气请求异常: {e}") + return None + + +# ============================================================================= +# 统一分发入口 +# ============================================================================= + + +async def handle_local_tool_request(user_message: str) -> str | None: + """处理本地工具请求的入口函数 + + 流程: + 1. 调用 classify_request 判断请求类型 + 2. 若为 LOCAL_TOOL,调用时间工具生成回复 + 3. 若为 TOOL_CALL(天气),尝试提取城市并调用天气工具 + 4. 若为其他类型,返回 None 让调用方继续走大模型对话 + 5. 
若发生异常,返回友好的兜底话术 + + 参数: + user_message: 用户输入的原始消息文本 + + 返回: + - 有效字符串:本地已处理完成,可直接用作对话回复 + - None:此请求不属于本地工具范围,调用方应继续走大模型 + + 用法: + reply = await handle_local_tool_request("现在几点?") + if reply is not None: + return reply + """ + try: + # 第一步:分类 + request_type = classify_request(user_message) + + # 第二步:时间查询(LOCAL_TOOL) + if request_type == RequestType.LOCAL_TOOL: + # 使用模块级单例,保持多轮对话状态跨请求 + # 走增强接口 get_reply_with_context,支持口语化/场景化/多轮对话 + # 时间偏移查询(1小时后几点)→ 由 get_reply_with_context 内部识别处理 + reply = _time_tool_instance.get_reply_with_context( + query_type="time", + user_message=user_message, + agent_type="通用", + ) + if reply: + return reply + return None + + # 第三步:天气查询(TOOL_CALL) + if request_type == RequestType.TOOL_CALL: + # 二次确认:is_weather_query 用专属规则树验证 + if is_weather_query(user_message): + reply = await handle_weather_request(user_message) + if reply: + return reply + # 不是天气的 TOOL_CALL(搜索/旅游等)→ 返回 None 走工具调用循环 + return None + + # 第四步:其他类型 → 返回 None + return None + + except Exception as e: + # 全链路兜底:任何异常都不中断对话,返回友好话术 + logger.warning(f"[LocalHandler] 处理本地工具请求异常: {e}") + return "抱歉,我暂时无法处理这个请求,您可以换种方式问我哦~" diff --git a/backend/app/utils/search_intent.py b/backend/app/utils/search_intent.py new file mode 100644 index 0000000..adb5ee3 --- /dev/null +++ b/backend/app/utils/search_intent.py @@ -0,0 +1,665 @@ +""" +搜索意图识别器 —— 多维度评分制判断用户是否需要联网搜索 + +设计思路: + 旧方案(纯关键词匹配)太片面,只能识别"搜索"、"查一下"等显式搜索词, + 无法覆盖大量隐式搜索需求,如: + - "2026年世界杯在哪举办"(地点查询) + - "iPhone 18什么时候出"(时间查询) + - "特斯拉股价多少"(实时数据) + - "最近有什么好看的电影"(时效性推荐) + - "河北软考几号"(具体事件日期) + + 新方案(八维度评分制)综合考虑以下信号: + 1. 问题模式层:疑问词+实体 → "X是几号"/"X什么时候"/"X在哪" + 2. 实体时效层:专有名词+时间词 → "2026年软考"/"最近新闻" + 3. 话题类别层:特定话题几乎必搜 → 股价/新闻/考试/赛事 + 4. 否定信号层:明确不需要搜索 → 数学/创作/编程/历史常识 + 5. 知识边界层:LLM 训练截止后的事实 → "2025年"/"最新政策" + 6. 实体识别层:专有名词+疑问结构 → "GPT-5什么时候出" + 7. 比较评价层:需要实时数据对比 → "X和Y哪个好" + 8. 
"""Search-intent recognizer — eight-dimension scoring (module header continued).

   8. fact-specificity: asks for exact numbers/dates/places → "X的录取分数线"

Each matched signal adds its weight; when the total reaches the threshold a
web search is triggered. An optional conversation history enables detection
of follow-up search questions ("那X呢" / "还有呢").
"""

import re
from loguru import logger


# =============================================================================
# Signal tables (compiled regex, weight, human-readable label)
# =============================================================================

# ---------------------------------------------------------------------------
# Dimension 1-3: question patterns (interrogatives + entities), timeliness
# words and must-search topic categories.
# ---------------------------------------------------------------------------
_POSITIVE_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"什么时候"), 4, "时间疑问"),
    (re.compile(r"是哪天|是几号|哪一天|几号"), 4, "日期疑问"),
    (re.compile(r"在哪里|在哪|在哪举办|在哪个"), 3, "地点疑问"),
    (re.compile(r"多少|多少钱|几钱|价格"), 3, "数值疑问"),
    (re.compile(r"有没有|是否|会不会|能不能"), 1, "是非疑问"),
    (re.compile(r"怎么样|如何|好不好"), 2, "评价疑问"),

    (re.compile(r"距离|离.*还有|还剩|还有几|还差几"), 4, "倒计时模式"),

    (re.compile(r"搜索|查找|搜一下|帮我搜|帮我查|查一下|查查|搜一搜"), 5, "显式搜索"),

    (re.compile(r"今年|明年|去年|本周|上周|最近|最新|当前|目前|现在|今日|昨日"), 2, "时效性词"),
    (re.compile(r"\d{4}年"), 1, "年份引用"),

    (re.compile(r"股价|股票|行情|涨幅|跌幅|市值|基金|比特币|加密货币"), 5, "金融实时"),
    (re.compile(r"新闻|热点|头条|爆料|事件|事故"), 4, "新闻热点"),
    (re.compile(r"考试|报名|准考证|成绩|录取|分数线|软考|考研|高考|中考|国考"), 4, "考试信息"),
    (re.compile(r"比赛|赛事|比分|积分|排名|赛程|对阵|世界杯|奥运会|欧冠|NBA"), 4, "体育赛事"),
    (re.compile(r"上映|票房|评分|豆瓣|IMDb|排行|榜单|推荐.*电影|好看.*剧|好看.*电影|好看.*片|有什么好看|有什么.*推荐"), 3, "影视娱乐"),
    (re.compile(r"航班|高铁|火车|机票|车次|时刻表|晚点"), 4, "交通出行"),
    (re.compile(r"政策|法规|规定|新规|出台|实施|生效"), 3, "政策法规"),
    (re.compile(r"发布|推出|上市|开售|预售|发售|新品"), 3, "产品发布"),

    (re.compile(r"[\u4e00-\u9fa5]{2,}(什么时候|是几号|是哪天|在哪|多少|怎么样)"), 3, "实体+疑问"),

    (re.compile(r"旅游|旅行|攻略|景点|酒店|民宿|机票|签证"), 3, "旅游出行"),

    (re.compile(r"iPhone|iPad|MacBook|华为|小米|三星|特斯拉|比亚迪|蔚来"), 2, "品牌产品"),

    (re.compile(r"推荐|好不好|值不值得|值得|怎么样"), 2, "推荐评价"),
]

# ---------------------------------------------------------------------------
# Dimension 5: knowledge-boundary signals — facts past the LLM training
# cutoff, which almost always require a search.
# ---------------------------------------------------------------------------
_KNOWLEDGE_BOUNDARY_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"202[5-9]年"), 3, "近年引用"),
    (re.compile(r"今年.*?(政策|规定|新规|考试|报名|分数线|录取|赛事|举办)"), 4, "今年+时效词"),
    (re.compile(r"最新.*?(政策|规定|版本|消息|动态|公告|通知|发布)"), 4, "最新+时效词"),
    (re.compile(r"当前.*?(状态|情况|进度|排名|价格|行情|政策)"), 4, "当前+状态词"),
    (re.compile(r"目前.*?(支持|可用|开放|上线|发布|运行)"), 3, "目前+状态词"),
    (re.compile(r"现在.*?(能不能|可不可以|是否可以|还来得及|还开不开)"), 3, "现在+可行性"),
    (re.compile(r"刚刚|刚|才|刚刚才|刚才"), 1, "即时性词"),
    (re.compile(r"有没有.*?(出|发|开|上|更新|修复|支持)"), 3, "有无更新"),
    (re.compile(r"什么时候.*?(出|发|开|上|更新|修复|支持|上线|开放)"), 4, "何时更新"),
    (re.compile(r"(什么是|什么叫|介绍下|介绍一下).{0,5}?(GPT|Claude|Gemini|DeepSeek|Kimi|Sora|Copilot|ChatGPT|OpenAI|Llama)"), 4, "AI概念查询"),
    (re.compile(r"(什么是|什么叫|介绍下|介绍一下).{0,5}?(202[5-9]|最新|新出|新规)"), 4, "时敏概念查询"),
]

# ---------------------------------------------------------------------------
# Dimension 6: entity signals — proper nouns combined with question structure.
# ---------------------------------------------------------------------------
_ENTITY_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"GPT-?\d|Claude|Gemini|Llama|通义|文心|千问|智谱|DeepSeek|Kimi|豆包"), 3, "AI产品"),
    (re.compile(r"Windows\s*\d+|macOS|iOS\s*\d+|Android\s*\d+|HarmonyOS"), 3, "操作系统"),
    (re.compile(r"Python\s*3\.\d+|Node\.js|React|Vue|Next\.js|Django|FastAPI"), 1, "编程框架"),
    (re.compile(r"ChatGPT|OpenAI|Anthropic|Google|Meta|Microsoft|Apple|Nvidia"), 2, "科技公司"),
    (re.compile(r"世博会|奥运会|世界杯|亚运会|冬奥会|欧洲杯|亚洲杯|全运会"), 4, "大型赛事"),
    (re.compile(r"双十一|618|黑五|双十二|年货节|购物节"), 3, "购物节"),
    (re.compile(r"考研|国考|省考|事业编|公务员|选调|教资|法考|注会|一建|二建"), 3, "考试名称"),
    (re.compile(r"诺贝尔|奥斯卡|格莱美|金球奖|金鸡奖|百花奖|茅盾奖"), 3, "奖项名称"),
    (re.compile(r"两会|人大|政协|党代会|中央全会|国务院"), 3, "政治事件"),
    (re.compile(r"[\u4e00-\u9fa5]{2,4}(省|市|自治区)(的)?(政策|规定|补贴|落户|限购)"), 3, "地方政策"),
    (re.compile(r"[\u4e00-\u9fa5]{2,6}(大学|学院|中学)(的)?(录取线|分数线|招生|排名)"), 3, "学校信息"),
    (re.compile(r"[\u4e00-\u9fa5]{2,}(医院|诊所)(的)?(挂号|排班|专家|门诊)"), 2, "医疗信息"),
    (re.compile(r"[\u4e00-\u9fa5]{2,}(地铁|公交|高铁|火车|航班)(的)?(时刻表|路线|班次|票价)"), 3, "交通信息"),
]

# ---------------------------------------------------------------------------
# Dimension 7: comparison/evaluation signals — need fresh data to compare.
# ---------------------------------------------------------------------------
_COMPARISON_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    # NOTE(review): "哪个好" appears twice in this alternation — redundant
    # but harmless; left as-is.
    (re.compile(r"和.{1,10}(哪个好|哪个值得|哪个好|哪个强|哪个便宜|怎么选|区别|对比|比较|不同|差异)"), 4, "比较选择"),
    (re.compile(r"还是.{1,10}(好|值得|强|便宜|划算)"), 3, "还是选择"),
    (re.compile(r"对比|比较|区别|差异|不同|优缺点|优劣"), 2, "对比词"),
    (re.compile(r"性价比|划算|值得买|推荐买|买哪个|选哪个"), 3, "购买决策"),
    (re.compile(r"排行|排名|Top|top|前十|前五|榜单|口碑"), 3, "排名推荐"),
    (re.compile(r"评测|测评|体验|使用感受|真实评价"), 3, "评测体验"),
]

# ---------------------------------------------------------------------------
# Dimension 8: fact-specificity signals — exact numbers / dates / places.
# ---------------------------------------------------------------------------
_FACT_SPECIFICITY_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"\d+月\d+[日号]"), 3, "具体日期"),
    (re.compile(r"放假|放假安排|假期|调休|补班|补休|调课"), 4, "假期安排"),
    (re.compile(r"录取线|分数线|合格线|及格线|最低分|最高分"), 4, "分数查询"),
    (re.compile(r"报名费|学费|票价|门票|价格|多少钱|收费"), 3, "价格查询"),
    (re.compile(r"营业时间|开放时间|上班时间|开门|关门|打烊"), 4, "时间查询"),
    (re.compile(r"地址|在哪|位置|怎么走|怎么去|路线|导航"), 2, "地点查询"),
    (re.compile(r"电话|联系方式|客服|咨询电话|预约"), 2, "联系方式"),
    (re.compile(r"要求|条件|资格|门槛|限制|年龄限制|学历要求"), 2, "条件查询"),
    (re.compile(r"流程|步骤|怎么办|如何办理|怎么申请|怎么操作"), 1, "流程查询"),
    (re.compile(r"名额|招生人数|招聘人数|录取人数|限购"), 3, "名额查询"),
    (re.compile(r"官网|官方网站|下载地址|下载链接|安装包"), 3, "官方资源"),
]

# ---------------------------------------------------------------------------
# Dimension 3 (supplement): real-time data signals — frequently changing data.
# ---------------------------------------------------------------------------
_REALTIME_DATA_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"汇率|换汇|外汇|人民币汇率|美元汇率|欧元汇率"), 5, "汇率查询"),
    (re.compile(r"油价|汽油价|柴油价|92号|95号|98号"), 5, "油价查询"),
    (re.compile(r"金价|黄金价|银价|铂金价|金价走势"), 5, "贵金属价格"),
    (re.compile(r"房价|二手房|新房|楼盘|均价|成交价"), 4, "房价查询"),
    (re.compile(r"限行|限号|尾号限行|单双号"), 4, "限行查询"),
    (re.compile(r"停水|停电|停气|检修|维修|维护通知"), 4, "民生通知"),
    (re.compile(r"快递|物流|发货|到货|运费|邮费"), 2, "物流查询"),
    (re.compile(r"招聘|求职|岗位|薪资|待遇|offer|面试结果"), 3, "招聘求职"),
    (re.compile(r"疫苗|挂号|核酸检测|门诊|就诊|医保"), 3, "医疗健康"),
    (re.compile(r"签证|护照|出入境|海关|入境政策"), 4, "出入境"),
]

# ---------------------------------------------------------------------------
# Dimension 4: negative signals — contexts that clearly need no search.
# ---------------------------------------------------------------------------
_NEGATIVE_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"计算|算一下|等于多少|\d+\s*[\+\-\*×xX÷/]\s*\d+"), 3, "数学计算"),
    (re.compile(r"写(一|个|段|篇)|帮我写|生成|创作|编一个|编个"), 3, "创作写作"),
    (re.compile(r"代码|编程|python|java|javascript|函数|算法|bug|调试"), 2, "编程问题"),
    (re.compile(r"翻译|translate"), 2, "翻译请求"),
    (re.compile(r"你好|早上好|晚上好|晚安|谢谢|再见"), 3, "寒暄闲聊"),
    (re.compile(r"^(今天|现在|当前)(几号|几点|几时|星期几|周几|什么时间|什么日期)$"), 5, "纯时间查询"),
]

# ---------------------------------------------------------------------------
# Negative overrides: "什么是X" where X is a recent concept still needs search.
# ---------------------------------------------------------------------------
_NEGATIVE_OVERRIDE_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"什么是.*?(GPT|Claude|Gemini|Llama|Sora|Copilot|DeepSeek|Kimi|豆包|通义|文心)"), 4, "AI概念解释"),
    (re.compile(r"什么是.*?(202[5-9]|最新|新出|新发|新规|新政)"), 4, "时敏概念解释"),
    (re.compile(r"解释一下.*?(政策|规定|新规|新法|改革|调整)"), 3, "时敏政策解释"),
    (re.compile(r"(什么是|什么叫|介绍下|介绍一下).{0,5}?(GPT|Claude|Gemini|DeepSeek|Kimi|Sora|Copilot|ChatGPT|OpenAI)"), 4, "AI产品解释"),
]

# ---------------------------------------------------------------------------
# Context awareness: follow-up search requests.
# ---------------------------------------------------------------------------
_FOLLOWUP_SEARCH_PATTERNS: list[tuple[re.Pattern, int, str]] = [
    (re.compile(r"^那.{1,10}(呢|怎么样|什么时候|在哪|多少|有没有)$"), 3, "追问话题"),
    (re.compile(r"^还有呢|还有吗|还有没有|除此之外"), 2, "追问补充"),
    (re.compile(r"^具体(一点|来说|是)|详细(一点|说说|介绍)"), 1, "追问详情"),
]

# Search threshold: total score >= this value triggers a search.
_SEARCH_THRESHOLD = 4
def _check_negative_override(clean_msg: str) -> int:
    """Return a bonus that offsets negative signals for time-sensitive concepts.

    "什么是GPT-5" matches the "什么是" negative signal, yet GPT-5 is a recent
    concept that still requires a search — the returned bonus cancels (part
    of) the penalty.

    Args:
        clean_msg: normalized message text.

    Returns:
        Sum of the weights of all matching override patterns.
    """
    return sum(
        weight
        for pattern, weight, _label in _NEGATIVE_OVERRIDE_PATTERNS
        if pattern.search(clean_msg)
    )


def compute_search_score(
    user_message: str,
    conversation_history: list[dict] | None = None,
) -> tuple[int, list[str]]:
    """Score the search need of *user_message* across eight dimensions.

    Dimensions: question patterns, entity timeliness, topic category,
    negative signals, knowledge boundary, entity recognition, comparison,
    and fact specificity; plus optional context-aware follow-up scoring.

    Args:
        user_message: raw user message.
        conversation_history: optional chat history for follow-up detection.

    Returns:
        (total score, list of matched signal descriptions).
    """
    if not user_message or not user_message.strip():
        return (0, [])

    clean_msg = user_message.replace("?", "").replace("?", "").replace(" ", "")
    score = 0
    signals: list[str] = []

    # Positive dimensions, evaluated in the original order; the tag is the
    # bracket prefix shown in the signal description.
    for table, tag in (
        (_POSITIVE_PATTERNS, ""),
        (_KNOWLEDGE_BOUNDARY_PATTERNS, "[KB]"),
        (_ENTITY_PATTERNS, "[ENT]"),
        (_COMPARISON_PATTERNS, "[CMP]"),
        (_FACT_SPECIFICITY_PATTERNS, "[FACT]"),
        (_REALTIME_DATA_PATTERNS, "[RT]"),
    ):
        for pattern, weight, label in table:
            if pattern.search(clean_msg):
                score += weight
                signals.append(f"+{weight} {tag}{label}")

    # Dimension 4: signals arguing against a search.
    neg_total = 0
    for pattern, weight, label in _NEGATIVE_PATTERNS:
        if pattern.search(clean_msg):
            neg_total += weight
            signals.append(f"-{weight} {label}")

    # A matching override (time-sensitive concept) cancels part of the penalty.
    if neg_total > 0:
        override_bonus = _check_negative_override(clean_msg)
        if override_bonus > 0:
            neg_total = max(0, neg_total - override_bonus)
            signals.append(f"+{override_bonus} 否定覆盖")
    score -= neg_total

    # Context awareness: follow-up right after a search-based assistant reply.
    if conversation_history:
        followup_bonus = _check_followup_context(clean_msg, conversation_history)
        if followup_bonus > 0:
            score += followup_bonus
            signals.append(f"+{followup_bonus} [CTX]追问搜索")

    # Simple follow-up patterns that need no history at all.
    for pattern, weight, label in _FOLLOWUP_SEARCH_PATTERNS:
        if pattern.search(clean_msg):
            score += weight
            signals.append(f"+{weight} [FUP]{label}")

    return (score, signals)


def _check_followup_context(
    clean_msg: str,
    conversation_history: list[dict],
) -> int:
    """Score follow-up questions asked right after a search-based answer.

    Example: assistant answered with search results, user asks "那报名条件
    是什么" or "后天呢" — such follow-ups get an extra search score.

    Args:
        clean_msg: normalized message text.
        conversation_history: list of {"role": ..., "content": ...} dicts.

    Returns:
        3 when a follow-up after a search reply is detected, otherwise 0.
    """
    if not conversation_history or len(conversation_history) < 2:
        return 0

    last_reply = next(
        (
            msg.get("content", "")
            for msg in reversed(conversation_history)
            if msg.get("role") == "assistant"
        ),
        "",
    )
    if not last_reply:
        return 0

    # Heuristic markers that the previous reply came from a web search.
    indicators = (
        "搜索结果", "查询到", "根据搜索", "网上", "来源:",
        "搜索显示", "查到", "检索到", "最新消息", "据报道",
    )
    if not any(word in last_reply for word in indicators):
        return 0

    for followup_pat in (
        r"^那.{1,10}(呢|怎么样|什么时候|在哪|多少|有没有)",
        r"^还有呢|还有吗|还有没有",
        r"^具体(一点|来说|是)",
    ):
        if re.search(followup_pat, clean_msg):
            return 3
    return 0


def needs_search(
    user_message: str,
    conversation_history: list[dict] | None = None,
) -> bool:
    """Return True when the message should trigger a web search.

    Args:
        user_message: raw user message.
        conversation_history: optional chat history for follow-up detection.
    """
    score, signals = compute_search_score(user_message, conversation_history)
    hit = score >= _SEARCH_THRESHOLD
    if hit:
        logger.info(f"[SearchIntent] 搜索意图命中: score={score}, signals={signals}")
    return hit


def extract_search_query(user_message: str) -> str:
    """Distill a search query from the raw message and add time context.

    Strips greeting / explicit-search / countdown / comparison phrasing,
    falls back to the original message when everything was stripped, then
    lets _enrich_query_with_time_context add year and half-year hints for
    periodic events (exams, tournaments, ...).

    Args:
        user_message: raw user message.

    Returns:
        Cleaned, time-context-enriched search query.
    """
    query = user_message.strip()

    for noise_pattern in (
        r"^(搜索|查找|搜一下|帮我搜|帮我查|查一下|查查|搜一搜|帮我搜索|帮我查找|帮我找|帮我找找|请问|麻烦问|问一下)\s*",
        # NOTE(review): this also strips "离" anywhere in the message
        # (e.g. "离职" → "职") — confirm that is intended.
        r"(距离|离)\s*",
        r"(还有几天|还剩几天|剩下几天|还有多久|还差几天)\s*$",
        r"^(请问|麻烦|帮忙|你好|您好)\s*",
        r"(和.{1,10}哪个好|和.{1,10}怎么选|还是.{1,5}好)\s*$",
    ):
        query = re.sub(noise_pattern, "", query).strip()

    if not query:
        query = user_message.strip()

    return _enrich_query_with_time_context(query)
如果清洗后为空,返回原始消息 + + 时间上下文增强: + - 当查询涉及考试/赛事等周期性事件时,自动推断当前应搜索的时间范围 + - 例如5月问"软考"→ 上半年已过 → 搜索"2026年下半年软考时间" + - 例如1月问"软考"→ 上半年未到 → 搜索"2026年上半年软考时间" + + 参数: + user_message: 用户原始消息 + + 返回: + 清洗后并添加时间上下文的搜索查询词 + """ + query = user_message.strip() + + query = re.sub( + r"^(搜索|查找|搜一下|帮我搜|帮我查|查一下|查查|搜一搜|帮我搜索|帮我查找|帮我找|帮我找找|请问|麻烦问|问一下)\s*", + "", query + ).strip() + + query = re.sub(r"(距离|离)\s*", "", query).strip() + query = re.sub(r"(还有几天|还剩几天|剩下几天|还有多久|还差几天)\s*$", "", query).strip() + + query = re.sub(r"^(请问|麻烦|帮忙|你好|您好)\s*", "", query).strip() + + query = re.sub(r"(和.{1,10}哪个好|和.{1,10}怎么选|还是.{1,5}好)\s*$", "", query).strip() + + if not query: + query = user_message.strip() + + query = _enrich_query_with_time_context(query) + + return query + + +def _enrich_query_with_time_context(query: str) -> str: + """为搜索查询添加时间上下文 + + 核心逻辑: + 1. 检测查询中是否包含周期性事件关键词(考试/赛事/节日等) + 2. 如果包含且查询中没有明确年份/上下半年 → 自动补充 + 3. 推断规则: + - 1-4月 → 搜索"上半年"(上半年考试通常5-6月举行) + - 5-8月 → 搜索"下半年"(上半年已过,下半年通常11月举行) + - 9-12月 → 搜索"下半年"(下半年考试通常11月举行) + + 参数: + query: 清洗后的搜索查询词 + + 返回: + 添加时间上下文后的搜索查询词 + """ + has_year = bool(re.search(r"20\d{2}年", query)) + has_half = bool(re.search(r"上半年|下半年", query)) + + if has_year and has_half: + return query + + periodic_event_patterns = [ + (re.compile(r"软考"), "考试"), + (re.compile(r"考研"), "考试"), + (re.compile(r"高考"), "考试"), + (re.compile(r"中考"), "考试"), + (re.compile(r"国考"), "考试"), + (re.compile(r"省考"), "考试"), + (re.compile(r"考公"), "考试"), + (re.compile(r"事业编"), "考试"), + (re.compile(r"教资|教师资格"), "考试"), + (re.compile(r"法考"), "考试"), + (re.compile(r"注会"), "考试"), + (re.compile(r"一建|二建"), "考试"), + (re.compile(r"公务员"), "考试"), + (re.compile(r"选调"), "考试"), + (re.compile(r"世界杯"), "赛事"), + (re.compile(r"奥运会"), "赛事"), + (re.compile(r"亚运会"), "赛事"), + (re.compile(r"欧冠"), "赛事"), + (re.compile(r"欧洲杯"), "赛事"), + (re.compile(r"亚洲杯"), "赛事"), + (re.compile(r"全运会"), "赛事"), + (re.compile(r"世博会"), "展会"), + (re.compile(r"进博会"), "展会"), + (re.compile(r"广交会"), "展会"), + (re.compile(r"双十一|618"), "购物节"), + 
(re.compile(r"春运"), "民生"), + (re.compile(r"秋招|春招"), "招聘"), + (re.compile(r"报名"), "报名"), + (re.compile(r"录取"), "录取"), + (re.compile(r"分数线"), "分数"), + ] + + matched_event = None + for pattern, event_type in periodic_event_patterns: + if pattern.search(query): + matched_event = event_type + break + + if not matched_event: + return query + + # 清洗疑问词和冗余词,提取核心搜索词 + core_query = query + core_query = re.sub(r"^今天|^现在|^当前|^目前|^今年", "", core_query) + core_query = re.sub(r"什么时候|几号|几时|哪天|哪一天|是哪天|是几号", "", core_query) + core_query = re.sub(r"还有几天|还剩几天|还有多久|还差几天$", "", core_query) + core_query = re.sub(r"多少|怎么样|好不好|有没有|是否", "", core_query) + core_query = core_query.strip() + + if not core_query: + core_query = query + + from datetime import datetime + now = datetime.now() + year = now.year + month = now.month + + enriched = core_query + if not has_year: + enriched = f"{year}年" + enriched + + if not has_half and matched_event in ("考试", "报名", "录取", "分数", "招聘"): + # 高考/中考固定在6月举行,始终搜索上半年 + first_half_only = bool(re.search(r"高考|中考", core_query)) + if first_half_only: + enriched = enriched + "上半年" + elif month >= 5 and month <= 8: + enriched = enriched + "下半年" + elif month >= 9: + enriched = enriched + "下半年" + else: + enriched = enriched + "上半年" + + if matched_event in ("考试", "报名", "录取", "分数") and "时间" not in enriched: + enriched = enriched + "时间" + + # 优化顺序:将"下半年/上半年"移到事件名后面、时间前面 + # 例如 "河北软考下半年时间" → "下半年河北软考时间"(更自然的搜索词) + enriched = re.sub( + r"^(20\d{2}年)(.+?)(上半年|下半年)(时间)$", + r"\1\3\2\4", + enriched + ) + + return enriched + + +def get_search_confidence( + user_message: str, + conversation_history: list[dict] | None = None, +) -> str: + """获取搜索意图的置信度等级 + + 用于前端展示或日志分析,帮助理解分类决策。 + + 参数: + user_message: 用户原始消息 + conversation_history: 对话历史(可选) + + 返回: + "high" / "medium" / "low" / "none" + """ + score, _ = compute_search_score(user_message, conversation_history) + if score >= 8: + return "high" + elif score >= _SEARCH_THRESHOLD: + return "medium" + elif score >= 2: + return "low" + 
# =============================================================================
# Direct-run self-test
# =============================================================================
if __name__ == "__main__":
    # (message, expected "needs search" verdict)
    test_cases = [
        # needs search — basic scenarios
        ("距离河北软考还有几天", True),
        ("2026年世界杯在哪举办", True),
        ("iPhone 18什么时候出", True),
        ("特斯拉股价多少", True),
        ("最近有什么好看的电影", True),
        ("帮我搜索一下Python教程", True),
        ("河北软考几号", True),
        ("今年考研什么时候报名", True),
        ("NBA总决赛比分", True),
        ("北京到上海的高铁时刻表", True),
        ("2026年软考时间安排", True),
        ("最近有什么新闻", True),
        ("推荐一个旅游景点", True),
        ("明天北京天气怎么样", True),

        # needs search — knowledge-boundary scenarios
        ("2025年有什么新政策", True),
        ("2026年世博会在哪举办", True),
        ("最新版本的ChatGPT是什么", True),
        ("目前GPT-5出了吗", True),
        ("今年国考什么时候报名", True),
        ("当前人民币汇率是多少", True),

        # needs search — entity scenarios
        ("GPT-5什么时候发布", True),
        ("DeepSeek最新模型是什么", True),
        ("Windows 12什么时候出", True),
        ("2026年亚运会在哪办", True),
        ("诺贝尔奖2025年得主是谁", True),
        ("广东省最新落户政策", True),

        # needs search — comparison scenarios
        ("iPhone 16和华为Mate70哪个好", True),
        ("比亚迪和特斯拉怎么选", True),
        ("考研和考公哪个更值得", True),
        ("笔记本电脑性价比排行", True),
        ("React和Vue哪个好", True),

        # needs search — fact-specificity scenarios
        ("5月1号放假安排", True),
        ("清华录取分数线多少", True),
        ("北京故宫门票多少钱", True),
        ("国家图书馆营业时间", True),
        ("软考报名条件是什么", True),
        ("GPT-4官网下载地址", True),

        # needs search — real-time data scenarios
        ("今天油价多少", True),
        ("黄金价格多少一克", True),
        ("北京今天限行尾号", True),
        ("上海二手房均价多少", True),
        ("美元汇率多少", True),
        ("最近有没有招聘会", True),

        # needs search — negative-override scenarios
        ("什么是GPT-5", True),
        ("什么是2025年新规", True),

        # no search needed
        ("今天几号", False),
        ("现在几点了", False),
        ("3+5等于多少", False),
        ("帮我写一段朋友圈文案", False),
        ("你好,请介绍一下你自己", False),
        ("Python怎么写快速排序", False),
        ("什么是量子力学", False),
        ("翻译一下这段话", False),
        ("解释一下相对论", False),
        ("今天天气不错", False),
        ("几点吃饭", False),
    ]

    print("=" * 80)
    print(" SearchIntent 八维度搜索意图识别 测试结果")
    print("=" * 80)
    print()

    passed = 0
    failed = 0
    for msg, expected in test_cases:
        score, signals = compute_search_score(msg)
        result = score >= _SEARCH_THRESHOLD
        if result == expected:
            status = "PASS"
            passed += 1
        else:
            status = "FAIL"
            failed += 1
        signal_str = ", ".join(signals[:4]) if signals else "无"
        confidence = get_search_confidence(msg)
        print(f" [{status}] {msg:40} | score={score:2d} | conf={confidence:6s} | "
              f"预期={'搜' if expected else '不搜':2s} | 实际={'搜' if result else '不搜':2s} | {signal_str}")

    print()
    print(f" 通过: {passed} 失败: {failed} 总计: {len(test_cases)}")

    if failed == 0:
        print("\n 全部测试通过!")
    else:
        print(f"\n 有 {failed} 个测试未通过,需要调整规则")
"""Local time tool — pure-local time/date/weekday replies (module header
continued); relies only on the Python standard library, zero external
dependencies.
"""

import re
import time as _time_module
from datetime import datetime, timedelta, date
from functools import lru_cache
from typing import Optional
from zoneinfo import ZoneInfo, available_timezones


# =============================================================================
# Weekday tables: Python weekday() numbers (0 = Monday) → Chinese names
# =============================================================================
_WEEKDAY_NAMES: dict[int, str] = {
    0: "星期一",
    1: "星期二",
    2: "星期三",
    3: "星期四",
    4: "星期五",
    5: "星期六",
    6: "星期日",
}

# Short "周X" forms of the same mapping.
_WEEKDAY_NAMES_SHORT: dict[int, str] = {
    0: "周一",
    1: "周二",
    2: "周三",
    3: "周四",
    4: "周五",
    5: "周六",
    6: "周日",
}

# Weekend weekday() numbers — used to decide whether to append weekend hints.
_WEEKEND_DAYS: set[int] = {5, 6}

# Chinese numeral → integer (additive; e.g. "十三" sums 10 + 3).
_CN_NUM: dict[str, int] = {
    "零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
    "五": 5, "六": 6, "七": 7, "八": 8, "九": 9, "十": 10,
}

# Weekday word → weekday() index, for "下周一"-style offsets.
_WEEKDAY_OFFSET_CN: dict[str, int] = {
    "周一": 0, "周二": 1, "周三": 2, "周四": 3,
    "周五": 4, "周六": 5, "周日": 6,
}

# =============================================================================
# Lunar calendar data (2025-2030)
# Layout: (solar (y, m, d), lunar (y, m, d, leap flag))
# Source: standard lunar calendar computation
# =============================================================================

# Lunar year → sexagenary (gan-zhi) year name.
_LUNAR_YEAR_NAMES: dict[int, str] = {
    2025: "乙巳", 2026: "丙午", 2027: "丁未",
    2028: "戊申", 2029: "己酉", 2030: "庚戌",
}

# Lunar month number (1-12) → display name; index 0 is a placeholder.
_LUNAR_MONTH_NAMES: list[str] = [
    "", "正月", "二月", "三月", "四月", "五月", "六月",
    "七月", "八月", "九月", "十月", "冬月", "腊月",
]

# Lunar day number (1-30) → display name; index 0 is a placeholder.
_LUNAR_DAY_NAMES: list[str] = [
    "", "初一", "初二", "初三", "初四", "初五", "初六", "初七", "初八", "初九", "初十",
    "十一", "十二", "十三", "十四", "十五", "十六", "十七", "十八", "十九", "二十",
    "廿一", "廿二", "廿三", "廿四", "廿五", "廿六", "廿七", "廿八", "廿九", "三十",
]

# Solar date of the first day of each lunar month (2025-2028), sorted
# ascending. Layout: (solar year, solar month, solar day,
#                     lunar year, lunar month, is leap month)
_LUNAR_MONTH_STARTS: list[tuple[int, int, int, int, int, bool]] = [
    # 2025 乙巳 year
    (2025, 1, 29, 2025, 1, False),
    (2025, 2, 28, 2025, 2, False),
    (2025, 3, 29, 2025, 3, False),
    (2025, 4, 28, 2025, 4, False),
    (2025, 5, 27, 2025, 5, False),
    (2025, 6, 26, 2025, 6, False),
    (2025, 7, 25, 2025, 6, True),  # leap 6th month
    (2025, 8, 23, 2025, 7, False),
    (2025, 9, 22, 2025, 8, False),
    (2025, 10, 21, 2025, 9, False),
    (2025, 11, 20, 2025, 10, False),
    (2025, 12, 19, 2025, 11, False),
    # 2026 丙午 year — Spring Festival 2026 is Feb 17
    (2026, 1, 18, 2025, 12, False),  # first day of lunar 2025 12th month
    (2026, 2, 17, 2026, 1, False),   # lunar New Year's Day (Spring Festival)
    (2026, 3, 19, 2026, 2, False),
    (2026, 4, 17, 2026, 3, False),
    (2026, 5, 16, 2026, 4, False),
    (2026, 6, 15, 2026, 5, False),
    (2026, 7, 14, 2026, 6, False),
    (2026, 8, 13, 2026, 7, False),
    (2026, 9, 11, 2026, 8, False),
    (2026, 10, 11, 2026, 9, False),
    (2026, 11, 9, 2026, 10, False),
    (2026, 12, 8, 2026, 11, False),
    # 2027 丁未 year — Spring Festival 2027 is Feb 6
    (2027, 1, 6, 2026, 12, False),   # first day of lunar 2026 12th month
    (2027, 2, 6, 2027, 1, False),    # lunar New Year's Day (Spring Festival)
    (2027, 3, 7, 2027, 2, False),
    (2027, 4, 5, 2027, 3, False),
    (2027, 5, 5, 2027, 4, False),
    (2027, 6, 4, 2027, 5, False),
    (2027, 7, 3, 2027, 5, True),     # leap 5th month
    (2027, 8, 2, 2027, 6, False),
    (2027, 8, 31, 2027, 7, False),
    (2027, 9, 30, 2027, 8, False),
    (2027, 10, 29, 2027, 9, False),
    (2027, 11, 28, 2027, 10, False),
    (2027, 12, 27, 2027, 11, False),
    # 2028 戊申 year — Spring Festival 2028 is Jan 26
    (2028, 1, 26, 2028, 1, False),   # lunar New Year's Day (Spring Festival)
    (2028, 2, 24, 2028, 2, False),
    (2028, 3, 25, 2028, 3, False),
    (2028, 4, 24, 2028, 4, False),
    (2028, 5, 23, 2028, 5, False),
    (2028, 6, 22, 2028, 6, False),
    (2028, 7, 21, 2028, 7, False),
    (2028, 8, 20, 2028, 8, False),
    (2028, 9, 18, 2028, 9, False),
    (2028, 10, 17, 2028, 10, False),
    (2028, 11, 16, 2028, 11, False),
    (2028, 12, 15, 2028, 12, False),
]

# Fixed (solar-calendar) holidays.
# Layout: (month, day, name, is solar calendar)
_FIXED_HOLIDAYS: list[tuple[int, int, str, bool]] = [
    (1, 1, "元旦", True),
    (2, 14, "情人节", True),
    (3, 8, "妇女节", True),
    (3, 12, "植树节", True),
    (4, 1, "愚人节", True),
    (5, 1, "劳动节", True),
    (5, 4, "青年节", True),
    (6, 1, "儿童节", True),
    (7, 1, "建党节", True),
    (8, 1, "建军节", True),
    (9, 10, "教师节", True),
    (10, 1, "国庆节", True),
    (10, 31, "万圣节", True),
    (12, 25, "圣诞节", True),
]
============================================================================= + +_PATTERN_TIME = re.compile( + r"(几点|几时|什么时间|啥时间|现在时间|当前时间|看时间|报时|time|clock)", +) + +_PATTERN_DATE = re.compile( + r"(几号|几月几|几月几日|什么日期|今天日期|当前日期|今天几|啥日期|" + r"年月日|日历|几月份|几月$)", +) + +_PATTERN_WEEKDAY = re.compile( + r"(星期几|周几|礼拜几|今天周|明天周|后天周|昨天周|周五|周六|周日|" + r"周一|周二|周三|周四|星期[一二三四五六日天]|周[一二三四五六日天])", +) + +_PATTERN_DATE_OFFSET = re.compile( + r"(明天|今日|今日|后天|大后天|昨天|前天|大前天|" + r"\d+天后|\d+天前|[一二三四五六七八九十]+天后|[一二三四五六七八九十]+天前|" + r"下周[一二三四五六日天]|下下周|上周[一二三四五六日天])", +) + +_PATTERN_LUNAR = re.compile( + r"(农历|阴历|初一|十五|元宵|端午|中秋|重阳|除夕|腊月|大年)", +) + +_PATTERN_HOLIDAY = re.compile( + r"(什么日子|什么节日|什么节|法定节假日|节假日|有没有假|放不放假|" + r"过节|节日|庆祝|纪念日)", +) + +_PATTERN_TIMEZONE = re.compile( + r"[的地]时间" + r"|" + r"时间(?!点|分|钟|段|候|长|差)" + r"|" + r"时区|UTC|GMT|时差", +) + +# 急迫语境关键词:用于安抚+精准报时 +_PATTERN_URGENT = re.compile( + r"(来不及|快迟到|赶时间|赶车|赶飞机|赶高铁|赶火车|" + r"要出发|马上|快点|赶紧|加速|匆忙|急着)" +) + +# 时间偏移关键词:X小时后、X分钟前、X天后等 +_PATTERN_TIME_OFFSET = re.compile( + r"(?P\d+|[一二三四五六七八九十]+)" + r"(?P个?(?:小时|分钟|天|钟头))" + r"(?P[后前]|之后|之前|了)" +) + + +# ============================================================================= +# 时间偏移解析 —— 支持口语化时间偏移查询 +# ============================================================================= + +def parse_time_offset(user_message: str) -> dict: + """解析用户输入的口语化时间偏移指令 + + 支持格式: + - X小时后、X个小时后、X分钟后、X天后 + - X小时前、X分钟前、X天前 + - X小时之后、X分钟之前 + - 中文数字:一小时后、三十分钟后 + + 参数: + user_message: 用户原始消息 + + 返回: + 字典: + - "value": 偏移数值(int) + - "unit": 单位("小时"/"分钟"/"天") + - "direction": 方向("后"/"前") + - "valid": 是否解析成功 + - "error": 解析失败时的提示信息 + """ + if not user_message: + return {"valid": False, "error": "消息为空"} + + cleaned = _clean_input(user_message) + + # 匹配偏移模式 + match = _PATTERN_TIME_OFFSET.search(cleaned) + if not match: + return {"valid": False, "error": "未识别到时间偏移指令"} + + num_raw = match.group("num") + unit_raw = match.group("unit") + dir_raw = match.group("dir") + + # 解析数值 + if num_raw.isdigit(): + value = 
int(num_raw) + else: + value = 0 + for ch in num_raw: + if ch in _CN_NUM: + value += _CN_NUM[ch] + + if value <= 0: + return {"valid": False, "error": "偏移数值必须大于0"} + + # 解析单位 + unit = "小时" + if "分钟" in unit_raw or "分" in unit_raw: + unit = "分钟" + elif "天" in unit_raw or "日" in unit_raw: + unit = "天" + elif "小时" in unit_raw or "钟头" in unit_raw or "时" in unit_raw: + unit = "小时" + + # 解析方向 + direction = "后" + if "前" in dir_raw: + direction = "前" + + return { + "valid": True, + "value": value, + "unit": unit, + "direction": direction, + "error": None, + } + + +def calc_offset_time(offset_info: dict, timezone: str = "Asia/Shanghai") -> str: + """根据偏移参数计算目标时间,返回自然语言结果 + + 参数: + offset_info: parse_time_offset 返回的字典 + timezone: 时区标识符 + + 返回: + 自然语言时间回复,如"1小时后是上午11点35分哦" + + 异常安全: + 偏移信息无效时,自动降级为当前时间查询 + """ + if not offset_info.get("valid"): + # 降级为当前时间 + now = datetime.now(ZoneInfo(timezone)) + _, _, time_str = TimeTool._format_time_oral(now.hour, now.minute) + return f"现在是{time_str}" + + value = offset_info["value"] + unit = offset_info["unit"] + direction = offset_info["direction"] + + # 计算目标时间 + now = datetime.now(ZoneInfo(timezone)) + if direction == "后": + if unit == "小时": + target = now + timedelta(hours=value) + elif unit == "分钟": + target = now + timedelta(minutes=value) + else: # 天 + target = now + timedelta(days=value) + else: # 前 + if unit == "小时": + target = now - timedelta(hours=value) + elif unit == "分钟": + target = now - timedelta(minutes=value) + else: # 天 + target = now - timedelta(days=value) + + # 格式化目标时间 + _, _, time_str = TimeTool._format_time_oral(target.hour, target.minute) + + # 判断日期是否变化 + date_changed = target.date() != now.date() + day_offset = (target.date() - now.date()).days + + if date_changed: + if day_offset == 1: + return f"{value}{unit}{direction}是明天{time_str}" + elif day_offset == -1: + return f"{value}{unit}{direction}是昨天{time_str}" + elif day_offset > 0: + return f"{value}{unit}{direction}是{day_offset}天后{time_str}" + else: + return 
f"{value}{unit}{direction}是{abs(day_offset)}天前{time_str}" + + return f"{value}{unit}{direction}是{time_str}" + + +def _is_time_offset_query(cleaned: str) -> bool: + """判断用户消息是否为时间偏移查询""" + return bool(_PATTERN_TIME_OFFSET.search(cleaned)) + + +def _clean_input(text: str) -> str: + """清洗输入文本,去除空格和中英文问号""" + cleaned = text.replace(" ", "").replace(" ", "") + cleaned = cleaned.replace("?", "").replace("?", "") + return cleaned + + +def _detect_query_type(cleaned: str) -> str: + """根据清洗后的用户消息,识别时间查询的子类型 + + 返回: + "time" - 当前时间 + "date" - 当前日期 + "week" - 星期几 + "date_offset" - 偏移日期(明天几号) + "week_offset" - 偏移星期(后天周几) + "lunar" - 农历日期 + "holiday" - 节假日 + "timezone" - 指定时区时间 + "all" - 综合信息 + """ + has_time = bool(_PATTERN_TIME.search(cleaned)) + has_date = bool(_PATTERN_DATE.search(cleaned)) + has_week = bool(_PATTERN_WEEKDAY.search(cleaned)) + has_offset = bool(_PATTERN_DATE_OFFSET.search(cleaned)) + has_lunar = bool(_PATTERN_LUNAR.search(cleaned)) + has_holiday = bool(_PATTERN_HOLIDAY.search(cleaned)) + has_timezone = bool(_PATTERN_TIMEZONE.search(cleaned)) + + if has_lunar: + return "lunar" + if has_holiday: + return "holiday" + if has_timezone and has_time: + return "timezone" + if has_offset and (has_date or has_week): + if has_date: + return "date_offset" + return "week_offset" + if has_offset: + return "date_offset" + + match_count = sum([has_time, has_date, has_week]) + if match_count == 1: + if has_time: + return "time" + if has_date: + return "date" + if has_week: + return "week" + + return "all" + + +def _extract_day_offset(cleaned: str, now: datetime | None = None) -> int: + """从用户消息中提取日期偏移量 + + 支持格式: + "明天"=1, "后天"=2, "大后天"=3, + "昨天"=-1, "前天"=-2, + "下周一"=next_monday, "3天后"=3 + + 参数: + cleaned: 清洗后的用户消息 + now: 当前时间(时区感知),None 时使用 datetime.now() + + 返回: + 距离今天的偏移天数,无法提取时返回 0 + """ + if "昨天" in cleaned: + return -1 + if "前天" in cleaned: + return -2 + if "大前天" in cleaned: + return -3 + if "明天" in cleaned or "明日" in cleaned: + return 1 + if "后天" in cleaned: + return 
2 + if "大后天" in cleaned: + return 3 + + offset_match = re.search(r"(\d+|[一二三四五六七八九十]+)\s*天?(后|前)", cleaned) + if offset_match: + raw = offset_match.group(1) + direction = offset_match.group(2) + if raw.isdigit(): + num = int(raw) + else: + num = 0 + for ch in raw: + if ch in _CN_NUM: + num += _CN_NUM[ch] + if direction == "前": + return -num + return num + + for week_word, weekday_idx in _WEEKDAY_OFFSET_CN.items(): + if week_word in cleaned: + today_weekday = (now or datetime.now()).weekday() + days_until = (weekday_idx - today_weekday) % 7 + is_next = "下" in cleaned or "下周" in cleaned or "下礼拜" in cleaned + if is_next and days_until == 0: + days_until = 7 + return days_until + + return 0 + + +def _solar_to_lunar(solar: date) -> dict: + """公历转农历""" + solar_tuple = (solar.year, solar.month, solar.day) + prev_start = None + for start in _LUNAR_MONTH_STARTS: + start_solar = (start[0], start[1], start[2]) + if start_solar <= solar_tuple: + prev_start = start + else: + break + if prev_start is None: + return {"found": False} + + prev_date = date(prev_start[0], prev_start[1], prev_start[2]) + delta_days = (solar - prev_date).days + lunar_year = prev_start[3] + lunar_month = prev_start[4] + is_leap = prev_start[5] if len(prev_start) > 5 else False + lunar_day = delta_days + 1 + + year_name = _LUNAR_YEAR_NAMES.get(lunar_year, "") + month_name = (_LUNAR_MONTH_NAMES[lunar_month] + if 1 <= lunar_month <= 12 else f"{lunar_month}月") + if is_leap: + month_name = "闰" + month_name + day_name = (_LUNAR_DAY_NAMES[lunar_day] + if 1 <= lunar_day < len(_LUNAR_DAY_NAMES) else f"{lunar_day}日") + spring_date = _get_spring_festival_date(lunar_year) + + return { + "found": True, + "lunar_year": lunar_year, + "lunar_month": lunar_month, + "lunar_day": lunar_day, + "is_leap": is_leap, + "year_name": year_name, + "month_name": month_name, + "day_name": day_name, + "spring_date": spring_date, + } + + +def _get_spring_festival_date(lunar_year: int) -> date | None: + """获取指定农历年春节(正月初一)的公历日期""" + 
for start in _LUNAR_MONTH_STARTS: + if start[3] == lunar_year and start[4] == 1 and ( + len(start) <= 5 or not start[5]): + return date(start[0], start[1], start[2]) + return None + + +def _get_holiday_info(solar: date, lunar_info: dict) -> str: + """获取法定节假日信息,无节假日返回空字符串""" + holidays = [] + month, day = solar.month, solar.day + for (m, d, name, _) in _FIXED_HOLIDAYS: + if m == month and d == day: + holidays.append(name) + + if lunar_info.get("found"): + lm = lunar_info["lunar_month"] + ld = lunar_info["lunar_day"] + if lm == 1 and ld == 1: + holidays.append("春节") + if lm == 1 and ld == 15: + holidays.append("元宵节") + if month == 4 and day in (4, 5): + holidays.append("清明节") + if lm == 5 and ld == 5: + holidays.append("端午节") + if lm == 7 and ld == 7: + holidays.append("七夕节") + if lm == 8 and ld == 15: + holidays.append("中秋节") + if lm == 9 and ld == 9: + holidays.append("重阳节") + if lm == 12 and ld == 30: + holidays.append("除夕") + elif lm == 12 and ld == 29: + for start in _LUNAR_MONTH_STARTS: + if start[3] == lunar_info["lunar_year"] and start[4] == 12: + next_idx = _LUNAR_MONTH_STARTS.index(start) + 1 + if next_idx < len(_LUNAR_MONTH_STARTS): + nxt = _LUNAR_MONTH_STARTS[next_idx] + cur = date(start[0], start[1], start[2]) + nxt_d = date(nxt[0], nxt[1], nxt[2]) + if (nxt_d - cur).days == 29: + holidays.append("除夕") + return "、".join(holidays) + + +def _get_holiday_message(holiday_name: str) -> str: + """根据节假日名称生成祝福语""" + holiday_messages = { + "元旦": "祝你元旦快乐,新年新气象!", + "春节": "祝你春节快乐,阖家幸福,万事如意!", + "元宵节": "元宵节快乐,记得吃汤圆哦~", + "情人节": "情人节快乐!", + "妇女节": "祝你节日快乐!", + "植树节": "植树节,一起爱护地球吧~", + "清明节": "清明时节雨纷纷,注意出行安全。", + "劳动节": "劳动节快乐,辛苦了!", + "青年节": "青年节快乐,保持年轻心态!", + "端午节": "端午节快乐,记得吃粽子~", + "儿童节": "儿童节快乐,保持童心!", + "七夕节": "七夕节快乐!", + "中秋节": "中秋节快乐,花好月圆人团圆!", + "国庆节": "国庆节快乐!", + "重阳节": "重阳节快乐,登高望远心情好~", + "万圣节": "万圣节快乐~", + "圣诞节": "圣诞节快乐!", + "除夕": "除夕快乐,辞旧迎新!", + } + return holiday_messages.get(holiday_name, f"{holiday_name}快乐!") + + +# 
============================================================================= +# 时区名称映射 —— 常见城市到时区 +# ============================================================================= + +_CITY_TIMEZONE_MAP = { + "北京": "Asia/Shanghai", "上海": "Asia/Shanghai", "广州": "Asia/Shanghai", + "深圳": "Asia/Shanghai", "杭州": "Asia/Shanghai", "成都": "Asia/Shanghai", + "西安": "Asia/Shanghai", "重庆": "Asia/Shanghai", "武汉": "Asia/Shanghai", + "南京": "Asia/Shanghai", "苏州": "Asia/Shanghai", "天津": "Asia/Shanghai", + "香港": "Asia/Hong_Kong", "澳门": "Asia/Macau", "台北": "Asia/Taipei", + "东京": "Asia/Tokyo", "大阪": "Asia/Tokyo", "北海道": "Asia/Tokyo", + "首尔": "Asia/Seoul", "釜山": "Asia/Seoul", + "新加坡": "Asia/Singapore", "曼谷": "Asia/Bangkok", + "吉隆坡": "Asia/Kuala_Lumpur", + "雅加达": "Asia/Jakarta", "马尼拉": "Asia/Manila", + "河内": "Asia/Ho_Chi_Minh", + "新德里": "Asia/Kolkata", "孟买": "Asia/Kolkata", + "科伦坡": "Asia/Colombo", + "迪拜": "Asia/Dubai", "利雅得": "Asia/Riyadh", + "德黑兰": "Asia/Tehran", + "伦敦": "Europe/London", "巴黎": "Europe/Paris", + "柏林": "Europe/Berlin", + "罗马": "Europe/Rome", "马德里": "Europe/Madrid", + "莫斯科": "Europe/Moscow", + "阿姆斯特丹": "Europe/Amsterdam", + "斯德哥尔摩": "Europe/Stockholm", + "纽约": "America/New_York", "洛杉矶": "America/Los_Angeles", + "芝加哥": "America/Chicago", "多伦多": "America/Toronto", + "温哥华": "America/Vancouver", "旧金山": "America/Los_Angeles", + "悉尼": "Australia/Sydney", "墨尔本": "Australia/Melbourne", + "奥克兰": "Pacific/Auckland", +} + +_TIMEZONE_KEYWORDS = { + "东八区": "Asia/Shanghai", "北京时间": "Asia/Shanghai", + "东京时间": "Asia/Tokyo", "日本时间": "Asia/Tokyo", + "首尔时间": "Asia/Seoul", "韩国时间": "Asia/Seoul", + "新加坡时间": "Asia/Singapore", "曼谷时间": "Asia/Bangkok", + "伦敦时间": "Europe/London", "英国时间": "Europe/London", + "巴黎时间": "Europe/Paris", "法国时间": "Europe/Paris", + "纽约时间": "America/New_York", "美国东部时间": "America/New_York", + "洛杉矶时间": "America/Los_Angeles", "美国西部时间": "America/Los_Angeles", + "悉尼时间": "Australia/Sydney", "澳洲时间": "Australia/Sydney", +} + + +def _detect_timezone_from_message(cleaned: str) -> str | 
None: + """从用户消息中提取时区信息""" + for keyword, tz in _TIMEZONE_KEYWORDS.items(): + if keyword in cleaned: + return tz + for city, tz in _CITY_TIMEZONE_MAP.items(): + if city in cleaned: + return tz + return None + + +# ============================================================================= +# ╔══════════════════════════════════════════════════════════════════════════╗ +# ║ 可配置常量 — 时段划分 / 问候语 / Agent风格 / 多轮对话阈值 ║ +# ║ 修改这些常量即可调整回复风格,无需改核心代码 ║ +# ╚══════════════════════════════════════════════════════════════════════════╝ +# ============================================================================= + +# ---- 时段划分(24小时制,左闭右开)---- +# 通过 (开始小时, 结束小时, 时段名, 12小时制偏移, 问候前缀, 后缀提醒) 描述 +_PERIODS = [ + # (start, end, name, display_offset, greetings_tuple, reminder) + (0, 6, "凌晨", 0, ("夜深了,", "凌晨好,"), "早点休息哦"), + (6, 9, "早上", 0, ("早上好,", "早安,新的一天开始啦,"), ""), + (9, 12, "上午", 0, ("上午好,", ""), ""), + (12, 14, "中午", 12, ("中午好,", ""), "别忘了按时吃饭~"), + (14, 18, "下午", 12, ("下午好,", ""), ""), + (18, 22, "晚上", 12, ("晚上好,", "傍晚好,"), ""), + (22, 24, "深夜", 12, ("夜深了,", ""), "早点休息哦"), +] + +# ---- 多轮对话阈值(秒)---- +_REPEAT_SAME_MINUTE = 60 # 1分钟内重复查询 → "还是XX时间哦" +_REPEAT_NEAR_MINUTE = 120 # 2分钟内重复查询 → "距离上次才过了X分钟" +_REPEAT_MAX_WINDOW = 180 # 超过3分钟视为正常查询(大于 NEAR 阈值) + +# ---- 工作日/周末场景化后缀 ---- +_WORKDAY_MOTIVATIONS = [ + "加油干,今天也是元气满满的一天!", + "搬砖时间到,一起加油吧~", + "新的一天,新的开始!", + "认真工作的你最帅/最美!", +] + +_WEEKEND_RELAXATIONS = [ + "周末愉快,好好享受休息时光~", + "周末啦,今天有什么计划吗?", + "周末是充电的好时机,放松一下吧~", +] + +_FRIDAY_CELEBRATION = "明天就是周末啦,再坚持一下~" + +# ---- 时段默认后缀池(每个时段都有,不再依赖reminder字段)---- +# 格式: {时段名: [后缀1, 后缀2, ...]} +_PERIOD_DEFAULT_SUFFIXES = { + "凌晨": [ + "早点休息,身体最重要", + "熬夜伤身,快去睡吧", + "这个点还没睡,是在加班吗?", + ], + "早上": [ + "新的一天开始了,精神点!", + "早餐吃了吗?", + "今天也要加油哦~", + ], + "上午": [ + "上午效率最高,抓紧干活!", + "工作/学习顺利吗?", + "记得适当休息,别一直盯着屏幕~", + ], + "中午": [ + "午饭吃了吗?", + "午休一下,下午更有精神", + "别吃太饱,容易犯困哈哈", + ], + "下午": [ + "下午容易犯困,来杯咖啡提提神?", + "再坚持一下,很快就下班了", + "工作/学习还顺利吗?", + ], + "晚上": [ + "晚饭吃了吗?", + "晚上是属于自己的时间,好好放松", + 
"今天过得怎么样?", + ], + "深夜": [ + "还不睡?明天还要早起呢", + "熬夜对皮肤不好哦", + "快去休息吧,晚安~", + ], +} + + +def _get_period_default_suffix(hour: int, scene_ctx: dict) -> str: + """根据当前时段返回默认场景化后缀 + + 参数: + hour: 当前小时(0-23) + scene_ctx: 场景上下文 + + 返回: + 随机选择的时段后缀,或空字符串 + """ + import random + # 匹配当前时段 + period_name = "深夜" + for (start, end, p_name, _, _, _) in _PERIODS: + if start <= hour < end: + period_name = p_name + break + + # 根据用户作息偏好调整 + wake_time = scene_ctx.get("user_wake_time", "") + sleep_time = scene_ctx.get("user_sleep_time", "") + + # 如果用户设置了起床时间,且当前接近起床时间,添加特殊提示 + if wake_time and period_name == "早上": + try: + wake_hour = int(wake_time.split(":")[0]) + if abs(hour - wake_hour) <= 1: + return "该起床啦,别赖床哦~" + except (ValueError, IndexError): + pass + + # 如果用户设置了睡觉时间,且当前接近睡觉时间,添加特殊提示 + if sleep_time and period_name in ("晚上", "深夜"): + try: + sleep_hour = int(sleep_time.split(":")[0]) + if hour >= sleep_hour: + return "该准备睡觉啦,晚安~" + except (ValueError, IndexError): + pass + + suffixes = _PERIOD_DEFAULT_SUFFIXES.get(period_name, []) + if suffixes: + return random.choice(suffixes) + return "" + + +# ---- Agent 类型回复风格配置 ---- +# 格式: {类型: {prefix: 前缀模板, suffix: 后缀模板, max_length: 最大字数, tone: 风格名}} +# 模板中 {time_str} 会被替换为实际时间 +# 模板中 {greeting} 会被替换为时段问候 +_AGENT_STYLES = { + "通用": { # general — 默认友好分时段 + "prefix": "{greeting}现在是{time_str}哦~", + "suffix": "", + "max_length": 60, + "tone": "友好自然", + }, + "闲聊": { # casual — 轻松生活化+互动感 + "prefix": "{greeting}现在已经是{time_str}啦~", + "suffix": "你在干嘛呢?", + "max_length": 50, + "tone": "轻松生活", + }, + "办公": { # office — 极简精准,无多余话术 + "prefix": "{time_str}", + "suffix": "", + "max_length": 20, + "tone": "极简精准", + }, + "旅游": { # travel — 结合天气/行程/目的地 + "prefix": "{greeting}现在是{time_str}~", + "suffix": "旅途愉快!", + "max_length": 80, + "tone": "旅行友好", + }, + "创作": { # creative — 文艺感表达 + "prefix": "{greeting}时光流转,已是{time_str}。", + "suffix": "灵感来了吗?", + "max_length": 60, + "tone": "文艺清新", + }, +} + +_DEFAULT_AGENT = "通用" + + +# 
============================================================================= +# TimeTool 类 +# ============================================================================= + +class TimeTool: + """本地时间工具类,封装时间获取与自然语言回复生成 + + 核心方法: + - get_reply_with_context() — 完整个性化回复入口(推荐) + - get_reply() — 基础回复入口(向后兼容) + - load_user_memory() — 从记忆系统加载配置 + + 用法: + tool = TimeTool(timezone="Asia/Shanghai") + reply = tool.get_reply_with_context("time", user_context={...}) + """ + + # ---- 类级别多轮对话状态 ---- + _last_query_time: float = 0.0 + _last_query_message: str = "" + + def __init__(self, timezone: str = "Asia/Shanghai", + agent_id: str | None = None): + """初始化时间工具 + + 参数: + timezone: 时区标识符,默认东八区,无效时回退 Asia/Shanghai + agent_id: 用户标识,用于从记忆系统读取个性化配置 + """ + if timezone not in available_timezones(): + timezone = "Asia/Shanghai" + self._timezone = ZoneInfo(timezone) + self._timezone_name = timezone + self._agent_id = agent_id + self._user_location = "" + self._user_profile: dict = {} # 来自记忆系统的完整用户画像 + self._user_schedule: list = [] # 用户日程 + self._user_preferences: dict = {} # 用户偏好 + # 实例级多轮状态(每个tool实例独立追踪) + self._instance_last_query_time: float = 0.0 + + # ------------------------------------------------------------------ + # 记忆系统对接(保留) + # ------------------------------------------------------------------ + + def load_user_memory(self, agent_id: str | None = None): + """从用户记忆系统加载个性化配置 + + 读取用户的时区、所在地、职业、日程、作息、偏好等信息, + 存入内部状态供回复生成使用。 + 记忆读取失败时保持默认配置,不影响基础功能。 + + 参数: + agent_id: 用户标识 + """ + if agent_id: + self._agent_id = agent_id + if not self._agent_id: + return + try: + from app.engines.memory.core import get_memory_storage + storage = get_memory_storage() + memory = storage.load(self._agent_id) + profile = memory.profile + if profile.timezone and profile.timezone in available_timezones(): + if profile.timezone != self._timezone_name: + self._timezone = ZoneInfo(profile.timezone) + self._timezone_name = profile.timezone + if profile.location: + self._user_location = 
profile.location + self._user_profile = { + "name": getattr(profile, "name", ""), + "location": getattr(profile, "location", ""), + "timezone": getattr(profile, "timezone", ""), + "occupation": getattr(profile, "occupation", ""), + "birthday": getattr(profile, "birthday", ""), + } + self._user_schedule = getattr(memory, "schedule", []) or [] + self._user_preferences = { + "reply_style": getattr(profile, "reply_style", "友好"), + "wake_time": getattr(profile, "wake_time", ""), + "sleep_time": getattr(profile, "sleep_time", ""), + } + except Exception: + pass + + # ------------------------------------------------------------------ + # 时间获取(保留,核心链路不动) + # ------------------------------------------------------------------ + + @staticmethod + @lru_cache(maxsize=64) + def _get_cached_now(minute_bucket: str, tz_name: str) -> datetime: + """带缓存的时间获取方法(按时区隔离缓存)""" + from datetime import timezone as _tz + tz = ZoneInfo(tz_name) + return datetime.now(tz) + + def _now(self) -> datetime: + """获取当前时间(带时区转换和1分钟缓存)""" + minute_bucket = datetime.now(self._timezone).strftime("%Y%m%d%H%M") + return self._get_cached_now(minute_bucket, self._timezone_name) + + # ================================================================== + # 回复生成全链路(本次全面重写) + # ================================================================== + + # ------------------------------------------------------------------ + # 口语化时间格式化 —— 核心格式化器 + # ------------------------------------------------------------------ + + @staticmethod + def _format_time_oral(hour: int, minute: int) -> tuple[str, str, str]: + """将24小时制时间转为口语化中文表达 + + 返回: (时段名, 12小时制小时数, 口语化分秒字符串) + 示例: + 9:05 → ("上午", 9, "9点零5分") + 12:00 → ("中午", 12, "12点整") + 20:30 → ("晚上", 8, "8点30分") + 0:10 → ("凌晨", 12, "12点零10分") + + 参数: + hour: 24小时制小时(0-23) + minute: 分钟(0-59) + + 返回: + (period_name, display_hour, time_str) + """ + # 匹配时段配置 + period_name = "深夜" + display_offset = 12 + for (start, end, p_name, offset, _, _rem) in _PERIODS: + if start <= hour < end or 
(start == 0 and hour == 0): + period_name = p_name + display_offset = offset + break + + # 12小时制转换 + if hour == 0: + display_hour = 12 + elif hour <= 12: + display_hour = hour if hour != 12 else 12 + else: + display_hour = hour - 12 + if display_offset != 0: + display_hour = hour - display_offset + if display_hour <= 0: + display_hour += 12 + + # 分钟口语化 + if minute == 0: + time_str = f"{period_name}{display_hour}点整" + elif minute < 10: + time_str = f"{period_name}{display_hour}点零{minute}分" + else: + time_str = f"{period_name}{display_hour}点{minute}分" + + return period_name, display_hour, time_str + + # ------------------------------------------------------------------ + # 多轮对话检测 + # ------------------------------------------------------------------ + + def _check_repeat_query(self, now_ts: float) -> str | None: + """检测是否重复查询时间,返回多轮对话提示 + + 返回值: + None — 非重复查询,正常生成回复 + 非空字符串 — 重复查询,直接返回此提示 + + 规则: + 同一分钟内 → "还是XX时间哦,才过了不到一分钟~" + 5分钟以内 → "现在是XX时间,距离上次问才过了X分钟" + """ + if self._instance_last_query_time == 0: + return None + + elapsed = now_ts - self._instance_last_query_time + + if elapsed <= _REPEAT_SAME_MINUTE: + now = self._now() + _, _, time_str = self._format_time_oral(now.hour, now.minute) + return f"还是{time_str}哦,才过了不到一分钟~" + + if elapsed <= _REPEAT_NEAR_MINUTE: + now = self._now() + _, _, time_str = self._format_time_oral(now.hour, now.minute) + mins = int(elapsed // 60) + mins_text = "1分钟" if mins <= 1 else f"{mins}分钟" + return f"现在是{time_str},距离你上次问才过了{mins_text}~" + + return None + + # ------------------------------------------------------------------ + # 场景化上下文构建 + # ------------------------------------------------------------------ + + def _build_scene_context(self, now: datetime, user_message: str = "", + user_context: dict | None = None) -> dict: + """构建场景化上下文,整合所有维度信息 + + 返回字典包含: + is_weekend, is_friday, holiday_name, is_urgent, + weather_data, travel_info, timezone_diff + + 参数: + now: 当前时间 + user_message: 用户消息(用于急迫语境检测) + user_context: 
外部传入的用户上下文(memory/weather/travel数据) + """ + ctx: dict = { + "is_weekend": now.weekday() in _WEEKEND_DAYS, + "is_friday": now.weekday() == 4, + "holiday_name": "", + "is_urgent": False, + "weather_data": None, + "travel_info": None, + "timezone_diff": 0, + "user_location": self._user_location, + "user_occupation": self._user_profile.get("occupation", ""), + } + + # 节假日 + lunar_info = _solar_to_lunar(now.date()) + holiday = _get_holiday_info(now.date(), lunar_info) + if holiday: + ctx["holiday_name"] = holiday.split("、")[0] + + # 急迫语境检测 + if user_message and _PATTERN_URGENT.search(_clean_input(user_message)): + ctx["is_urgent"] = True + + # 外部上下文(天气/行程/时区差) + if user_context: + ctx["weather_data"] = user_context.get("weather") + ctx["travel_info"] = user_context.get("travel") + if user_context.get("remote_timezone"): + try: + remote_tz = ZoneInfo(user_context["remote_timezone"]) + remote_now = datetime.now(remote_tz) + local_now = datetime.now(self._timezone) + diff_hours = (remote_now.utcoffset().total_seconds() - + local_now.utcoffset().total_seconds()) / 3600 + ctx["timezone_diff"] = int(diff_hours) + except Exception: + pass + + return ctx + + # ------------------------------------------------------------------ + # 时段问候 + 场景后缀 + # ------------------------------------------------------------------ + + @staticmethod + def _get_greeting(hour: int, scene_ctx: dict, agent_type: str) -> str: + """根据时段和场景生成问候前缀 + + 优先级:急迫安抚 > 办公极简 > 分时段问候 + 节假日问候不再覆盖时段问候,而是叠加到后缀中 + """ + # 办公型无问候 + if agent_type == "办公": + return "" + + # 急迫语境 → 安抚话术(前缀即完整开头) + if scene_ctx.get("is_urgent"): + now_ts = datetime.now() + _, _, time_str = TimeTool._format_time_oral(now_ts.hour, now_ts.minute) + return f"别慌别慌,现在是{time_str}" + + # 分时段问候(节假日不在这里处理,放到后缀中叠加) + for (start, end, _, _, greetings, _reminder) in _PERIODS: + if start <= hour < end: + return greetings[0] + return "" + + @staticmethod + def _get_scene_suffix(hour: int, scene_ctx: dict, agent_type: str) -> str: + """生成场景化后缀 —— 
互斥选择,单次回复仅1个最适配短句 + + 按优先级从高到低遍历场景,命中第一个即返回,禁止叠加: + 1. 急迫安抚 + 2. 行程提醒 + 3. 节假日问候 + 4. 周末/周五 + 5. 天气联动 + 6. 时段默认后缀 + 7. 工作日激励(兜底) + + 参数: + hour: 当前小时 + scene_ctx: 场景上下文 + agent_type: Agent类型 + + 返回: + 单个场景短句,或空字符串 + """ + import random + + # ---- 1. 急迫安抚(最高优先级)---- + if scene_ctx.get("is_urgent"): + return "深呼吸,别着急,来得及的" + + # ---- 2. 行程提醒 ---- + travel = scene_ctx.get("travel_info") + if travel: + t_time = travel.get("time", "") + t_type = travel.get("type", "") + if t_time: + try: + t_dt = datetime.fromisoformat(t_time) + t_str = t_dt.strftime("%H:%M") + remaining = t_dt - datetime.now() + if 0 < remaining.total_seconds() < 7200: + hours_left = int(remaining.total_seconds() // 3600) + mins_left = int((remaining.total_seconds() % 3600) // 60) + parts = [] + if hours_left > 0: + parts.append(f"{hours_left}小时") + if mins_left > 0: + parts.append(f"{mins_left}分钟") + if t_type: + return (f"距离你{t_str}的{t_type}还有" + f"{''.join(parts)},记得提前出发") + except Exception: + pass + + # ---- 3. 节假日问候 ---- + holiday = scene_ctx.get("holiday_name", "") + if holiday: + return _get_holiday_message(holiday) + + # ---- 4. 周末/周五 ---- + if scene_ctx.get("is_weekend"): + return random.choice(_WEEKEND_RELAXATIONS) + if scene_ctx.get("is_friday"): + return _FRIDAY_CELEBRATION + + # ---- 5. 天气联动 ---- + weather = scene_ctx.get("weather_data") + if weather and agent_type not in ("办公",): + w_desc = weather.get("desc", "") + temp = weather.get("temp", "") + if w_desc and "雨" in w_desc: + return "今天有雨,出门记得带伞" + if temp: + try: + t = float(temp) if isinstance(temp, (int, float)) else 20 + if isinstance(temp, str): + t = float(temp.replace("°C", "").replace("℃", "")) + if t <= 8: + return "外面挺冷的,多穿点" + except (ValueError, TypeError): + pass + + # ---- 6. 时段默认后缀 ---- + period_suffix = _get_period_default_suffix(hour, scene_ctx) + if period_suffix: + return period_suffix + + # ---- 7. 
工作日激励(兜底)---- + if (not scene_ctx.get("is_weekend") + and not scene_ctx.get("is_friday") + and not holiday + and agent_type not in ("办公",)): + return random.choice(_WORKDAY_MOTIVATIONS) + + return "" + + # ------------------------------------------------------------------ + # Agent 风格适配 + # ------------------------------------------------------------------ + + @staticmethod + def _apply_agent_style(time_str: str, greeting: str, suffix: str, + agent_type: str) -> str: + """根据 Agent 类型组装最终回复""" + style = _AGENT_STYLES.get(agent_type, _AGENT_STYLES[_DEFAULT_AGENT]) + prefix_tpl = style["prefix"] + suffix_tpl = style["suffix"] + + # 渲染前缀 + prefix = prefix_tpl.format(greeting=greeting, time_str=time_str) + + # 拼接 + parts = [prefix] + if suffix: + parts.append(suffix) + if suffix_tpl and agent_type != "通用": + parts.append(suffix_tpl) + + return "".join(parts) + + # ------------------------------------------------------------------ + # 三级兜底机制 + # ------------------------------------------------------------------ + + def _generate_with_fallback(self, now: datetime, user_message: str, + user_context: dict | None, + agent_type: str) -> str: + """三级兜底回复生成 + + 判断顺序(关键修复:特殊查询优先于重复提问): + 1. 时间偏移查询(1小时后几点)→ 直接计算偏移时间 + 2. 日期/时区/农历等特殊查询 → 走对应专用逻辑 + 3. 重复提问检测(2分钟窗口)→ 短话术回复 + 4. 
正常场景化回复 → 互斥选择1个场景短句 + + 一级:完整个性化场景化回复(含记忆/天气/行程联动) + 二级:通用友好分时段回复(降级,不需要额外数据) + 三级:极简精准报时(最终安全兜底,只报时间不报日期) + """ + try: + # ---- 步骤1:时间偏移查询优先判断 ---- + cleaned = _clean_input(user_message) if user_message else "" + if cleaned and _is_time_offset_query(cleaned): + offset_info = parse_time_offset(user_message) + if offset_info.get("valid"): + return calc_offset_time( + offset_info, timezone=self._timezone_name) + # 偏移解析失败 → 降级为当前时间 + _, _, time_str = self._format_time_oral(now.hour, now.minute) + return f"现在是{time_str}" + + # ---- 步骤2:日期/时区/农历等特殊查询 ---- + query_type = _detect_query_type(cleaned) + if query_type in ("date", "date_offset", "week", "week_offset", + "lunar", "holiday", "timezone"): + return self.get_reply(query_type, user_message=user_message) + + # ---- 步骤3:多轮对话检测(2分钟窗口)---- + now_ts = _time_module.time() + repeat_msg = self._check_repeat_query(now_ts) + if repeat_msg: + self._instance_last_query_time = now_ts + return repeat_msg + + # 更新时间记录 + self._instance_last_query_time = now_ts + + # ---- 步骤4:正常场景化回复 ---- + scene_ctx = self._build_scene_context( + now, user_message, user_context) + + # 口语化格式化 + _, _, time_str = self._format_time_oral(now.hour, now.minute) + + # 问候 + greeting = self._get_greeting(now.hour, scene_ctx, agent_type) + + # 急迫语境:问候已包含时间,跳过 agent 前缀 + if scene_ctx.get("is_urgent"): + base = greeting + else: + base = self._apply_agent_style( + time_str, greeting, "", agent_type) + + # 办公型极简返回 + if agent_type == "办公" and not scene_ctx.get("is_urgent"): + return time_str + + # 互斥选择1个场景短句 + suffix = self._get_scene_suffix(now.hour, scene_ctx, agent_type) + + result = f"{base}" + if suffix: + sep = "。" if scene_ctx.get("is_urgent") else "," + result += sep + suffix + return result + + except Exception: + # ---- 二级:通用友好降级 ---- + try: + _, _, time_str = self._format_time_oral(now.hour, now.minute) + hour = now.hour + for (start, end, _, _, greetings, _) in _PERIODS: + if start <= hour < end: + return f"{greetings[0]}现在是{time_str}哦~" + return 
f"现在是{time_str}哦~" + except Exception: + # ---- 三级:极简硬编码兜底 ---- + h = now.hour + m = now.minute + display = h % 12 or 12 + m_str = "点整" if m == 0 else f"零{m}分" if m < 10 else f"{m}分" + return f"现在是{display}点{m_str}" + + # ================================================================== + # 公开接口 + # ================================================================== + + def get_reply_with_context(self, query_type: str = "all", + user_message: str = "", + user_context: dict | None = None, + agent_type: str = "通用") -> str: + """完整个性化回复入口(推荐使用) + + 自动整合所有维度信息生成回复: + 1. 口语化时间格式 + 2. 多轮对话检测 + 3. 六段场景问候 + 4. 工作日/周末/节假日 + 5. 急迫语境安抚 + 6. 记忆联动(时区/所在地/职业/偏好) + 7. 跨工具联动(天气/行程) + 8. Agent 风格适配 + 9. 三级兜底保护 + + 参数: + query_type: 查询类型(time/date/week/all/date_offset/等) + user_message: 用户原始消息(急迫检测 + 偏移计算用) + user_context: 用户上下文(可选),格式: + {"weather": {"desc":"晴","temp":"22"}, + "travel": {"type":"航班","time":"2026-05-07T10:00"}, + "remote_timezone": "Asia/Tokyo"} + agent_type: Agent类型,"通用"|"闲聊"|"办公"|"旅游"|"创作" + + 返回: + 经过所有规则处理的最终自然语言回复 + + 用法: + reply = tool.get_reply_with_context( + "time", "快迟到了现在几点", + user_context={"travel": {"type": "会议", "time": "..."}}, + agent_type="办公" + ) + """ + # 兜底值 + if agent_type not in _AGENT_STYLES: + agent_type = _DEFAULT_AGENT + + # 日期类查询保持原有逻辑不变 + if query_type in ("date", "date_offset", "week", "week_offset", + "lunar", "holiday", "timezone"): + return self.get_reply(query_type, user_message=user_message) + + # 时间和综合查询走新逻辑 + now = self._now() + if query_type in ("time", "all"): + return self._generate_with_fallback( + now, user_message, user_context, agent_type) + + # 其他类型 fallback 到旧方法 + return self.get_reply(query_type, user_message=user_message) + + def get_reply(self, query_type: str, user_message: str = "") -> str: + """统一回复入口 + + 参数: + query_type: 查询类型(time/date/date_offset/week/week_offset/lunar/holiday/timezone) + user_message: 用户原始消息 + + 返回: + 自然语言回复字符串 + """ + now = self._now() + + if query_type == "time": + _, _, time_str = 
self._format_time_oral(now.hour, now.minute) + greeting = self._contextual_greeting(now.hour) + return f"{greeting}现在是{time_str}哦~" + + if query_type in ("date", "date_offset"): + offset = _extract_day_offset(user_message, now) if (query_type == "date_offset" and user_message) else 0 + target_date = now.date() + timedelta(days=offset) + target_dt = datetime.combine(target_date, now.time()).replace(tzinfo=self._timezone) + year, month, day = target_dt.year, target_dt.month, target_dt.day + weekday = _WEEKDAY_NAMES[target_dt.weekday()] + + prefix_map = {-1: "昨天是", -2: "前天是", 1: "明天是", 2: "后天是"} + prefix = prefix_map.get(offset, "") + if not prefix: + if offset > 0: + prefix = f"{offset}天后是" + elif offset < 0: + prefix = f"{abs(offset)}天前是" + else: + prefix = "今天是" + + lunar_info = _solar_to_lunar(target_date) + holiday = _get_holiday_info(target_date, lunar_info) + holiday_text = "" + if holiday: + first = holiday.split("、")[0] + holiday_text = f",{_get_holiday_message(first)}" + + reply = f"{prefix}{year}年{month}月{day}日,{weekday}{holiday_text}" + if lunar_info.get("found"): + m_name = lunar_info["month_name"] + d_name = lunar_info["day_name"] + y_name = lunar_info["year_name"] + if m_name not in ("正月", "腊月") or d_name != "初一": + reply += f"(农历{y_name}年{m_name}{d_name})" + return reply + + if query_type in ("week", "week_offset"): + offset = _extract_day_offset(user_message, now) if (query_type == "week_offset" and user_message) else 0 + target_date = now.date() + timedelta(days=offset) + target_dt = datetime.combine(target_date, now.time()).replace(tzinfo=self._timezone) + weekday = _WEEKDAY_NAMES[target_dt.weekday()] + weekday_num = target_dt.weekday() + + prefix_map = {1: "明天是", 2: "后天是", -1: "昨天是"} + prefix = prefix_map.get(offset, "") + if not prefix: + if offset > 0: + prefix = f"{offset}天后是" + elif offset < 0: + prefix = f"{abs(offset)}天前是" + else: + prefix = "今天是" + + if weekday_num in _WEEKEND_DAYS: + return f"{prefix}{weekday}呢,好好享受周末时光吧~" + elif 
weekday_num == 4: + return f"{prefix}{weekday},马上就要周末啦,加油!" + return f"{prefix}{weekday}~" + + if query_type == "lunar": + lunar_info = _solar_to_lunar(now.date()) + if not lunar_info.get("found"): + return f"今天是{now.strftime('%Y-%m-%d')}," \ + "很抱歉暂时没有该日期的农历数据哦~" + year_name = lunar_info["year_name"] + month_name = lunar_info["month_name"] + day_name = lunar_info["day_name"] + holiday_name = _get_holiday_info(now.date(), lunar_info) + holiday_text = "" + if holiday_name: + holiday_names = holiday_name.split("、") + holiday_text = "," + _get_holiday_message(holiday_names[0]) + reply = f"今天是农历{year_name}年{month_name}{day_name}{holiday_text}" + if lunar_info.get("spring_date"): + spring = lunar_info["spring_date"] + reply += f"(今年春节是{spring.year}年{spring.month}月{spring.day}日)" + return reply + + if query_type == "holiday": + lunar_info = _solar_to_lunar(now.date()) + holiday_name = _get_holiday_info(now.date(), lunar_info) + year, month, day = now.year, now.month, now.day + weekday = _WEEKDAY_NAMES[now.weekday()] + if holiday_name: + holiday_names = holiday_name.split("、") + first = holiday_names[0] + msg = _get_holiday_message(first) + reply = f"今天是{year}年{month}月{day}日{weekday},{msg}" + if len(holiday_names) > 1: + reply += f"同时还是{holiday_name},今天可是个好日子!" 
+ return reply + return f"今天是{year}年{month}月{day}日{weekday},今天不是法定节假日哦~" + + if query_type == "timezone": + cleaned = _clean_input(user_message) if user_message else "" + tz_name = self._detect_tz_city_name(cleaned) + if tz_name: + tz_id = (_CITY_TIMEZONE_MAP.get(tz_name) or + _TIMEZONE_KEYWORDS.get(tz_name + "时间")) + if tz_id and tz_id in available_timezones(): + tz_tool = TimeTool(timezone=tz_id) + tz_now = tz_tool._now() + _, _, tz_time_str = tz_tool._format_time_oral(tz_now.hour, tz_now.minute) + return f"{tz_name}现在是{tz_time_str}哦~" + return f"抱歉,暂时不支持查询「{tz_name}」的时区信息哦~" + + # 综合回复 + year, month, day = now.year, now.month, now.day + hour, minute = now.hour, now.minute + weekday = _WEEKDAY_NAMES[now.weekday()] + weekday_num = now.weekday() + + _, _, time_str = self._format_time_oral(hour, minute) + greeting = self._contextual_greeting(hour) + base = f"{greeting}现在是{year}年{month}月{day}日{weekday}{time_str}" + + lunar_info = _solar_to_lunar(now.date()) + holiday_name = _get_holiday_info(now.date(), lunar_info) + if holiday_name: + holiday_names = holiday_name.split("、") + base += f",今天是{holiday_name}" + base += "," + _get_holiday_message(holiday_names[0]) + + if weekday_num in _WEEKEND_DAYS and not holiday_name: + base += ",祝您周末愉快!" 
+ elif weekday_num == 4: + base += ",明天就是周末啦,再坚持一下~" + return base + + def _contextual_greeting(self, hour: int) -> str: + """时段问候""" + for (start, end, _, _, greetings, _) in _PERIODS: + if start <= hour < end: + return greetings[0] + return "" + + @staticmethod + def _detect_tz_city_name(cleaned: str) -> str | None: + """时区城市名检测""" + for city in _CITY_TIMEZONE_MAP: + if city in cleaned: + return city + for keyword in _TIMEZONE_KEYWORDS: + if keyword in cleaned: + return keyword.replace("时间", "").replace("时区", "") + return None + + +# ============================================================================= +# 全局单例与对外接口 +# ============================================================================= + +_time_tool_instance: Optional[TimeTool] = None +_time_tool_timezone: str = "Asia/Shanghai" +_time_tool_agent_id: Optional[str] = None + + +def _get_time_tool(timezone: str = "Asia/Shanghai", + agent_id: str | None = None) -> TimeTool: + """获取 TimeTool 单例,时区/用户变更时重建""" + global _time_tool_instance, _time_tool_timezone, _time_tool_agent_id + if (_time_tool_instance is None + or timezone != _time_tool_timezone + or agent_id != _time_tool_agent_id): + _time_tool_instance = TimeTool(timezone=timezone, agent_id=agent_id) + _time_tool_timezone = timezone + _time_tool_agent_id = agent_id + return _time_tool_instance + + +def get_time_reply_enhanced(user_message: str, + timezone: str = "Asia/Shanghai", + agent_id: str | None = None, + user_context: dict | None = None, + agent_type: str = "通用") -> str: + """增强对外接口 —— 支持个性化上下文 + 多Agent风格 + + 参数: + user_message: 用户原始消息 + timezone: 时区(默认东八区) + agent_id: 用户标识(记忆系统读取用) + user_context: 用户上下文 {"weather":..., "travel":..., "remote_timezone":...} + agent_type: Agent类型,"通用"|"闲聊"|"办公"|"旅游"|"创作" + + 用法: + reply = get_time_reply_enhanced( + "现在几点", agent_type="办公", + user_context={"weather": {"desc": "雨", "temp": "12"}} + ) + """ + if not user_message: + return "" + cleaned = _clean_input(user_message) + if not cleaned: + return "" + 
query_type = _detect_query_type(cleaned) + tool = _get_time_tool(timezone=timezone, agent_id=agent_id) + + if query_type != "timezone": + detected_tz = _detect_timezone_from_message(cleaned) + if detected_tz and detected_tz != timezone: + tool = _get_time_tool(timezone=detected_tz, agent_id=agent_id) + + return tool.get_reply_with_context( + query_type, user_message=user_message, + user_context=user_context, agent_type=agent_type + ) + + +# ============================================================================= +# 直接运行验证(python -m app.utils.time_tool) +# ============================================================================= +if __name__ == "__main__": + import random + + print("=" * 74) + print(" TimeTool 回复优化版 全场景测试") + print("=" * 74) + + successful = 0 + total = 0 + + def test_one(label: str, fn, *args, **kw) -> str: + global successful, total + total += 1 + try: + result = fn(*args, **kw) + if result: + print(f"\n [{label}]") + print(f" {result}") + successful += 1 + else: + print(f"\n [FAIL] {label} -> 空回复") + except Exception as e: + print(f"\n [FAIL] {label} -> 异常: {e}") + return "" + + # ---- 基础时间查询 ---- + test_one("基础-现在几点了", get_time_reply_enhanced, "现在几点了") + test_one("基础-今天几号", get_time_reply_enhanced, "今天几号") + test_one("基础-今天周几", get_time_reply_enhanced, "今天星期几") + + # ---- 日期偏移 ---- + test_one("偏移-明天几号", get_time_reply_enhanced, "明天几号") + test_one("偏移-后天周几", get_time_reply_enhanced, "后天是星期几") + test_one("偏移-下周一", get_time_reply_enhanced, "下周一") + + # ---- 农历/节假日 ---- + test_one("农历-今天", get_time_reply_enhanced, "农历今天") + test_one("节假日-今天什么日子", get_time_reply_enhanced, "今天是什么日子") + + # ---- 时区 ---- + test_one("时区-东京时间", get_time_reply_enhanced, "现在东京时间几点") + + # ---- 多Agent风格对比(各自独立实例,避免多轮干扰)---- + print("\n" + "=" * 74) + print(" Agent 风格对比 (同一消息: '现在几点了')") + print("=" * 74) + for agent in ["通用", "闲聊", "办公", "旅游", "创作"]: + # 每个Agent用独立实例,展示各自风格 + t = TimeTool(timezone="Asia/Shanghai") + r = t.get_reply_with_context("time", "现在几点了", 
agent_type=agent) + total += 1 + if r: + print(f"\n [Agent-{agent}]") + print(f" {r}") + successful += 1 + else: + print(f"\n [FAIL] Agent-{agent} -> 空回复") + + # ---- 场景化上下文联动(每个场景独立实例)---- + def test_one_direct(label: str, result: str): + global successful, total + total += 1 + if result: + print(f"\n [{label}]") + print(f" {result}") + successful += 1 + else: + print(f"\n [FAIL] {label} -> 空回复") + + print("\n" + "=" * 74) + print(" 场景化联动测试") + print("=" * 74) + + t = TimeTool(timezone="Asia/Shanghai") + r = t.get_reply_with_context( + "time", "现在几点", agent_type="通用", + user_context={"weather": {"desc": "中雨", "temp": "12"}}) + test_one_direct("联动-下雨天", r) + + t2 = TimeTool(timezone="Asia/Shanghai") + r2 = t2.get_reply_with_context( + "time", "现在几点了", agent_type="通用", + user_context={"travel": { + "type": "航班", + "time": (datetime.now() + timedelta(hours=1, minutes=30) + ).isoformat() + }}) + test_one_direct("联动-有行程", r2) + + t3 = TimeTool(timezone="Asia/Shanghai") + r3 = t3.get_reply_with_context("time", "来不及了现在几点", agent_type="通用") + test_one_direct("联动-急迫语境", r3) + + # ---- 多轮对话模拟(共享同一实例)---- + print("\n" + "=" * 74) + print(" 多轮对话测试(同一实例连续3次查询)") + print("=" * 74) + multi_tool = TimeTool(timezone="Asia/Shanghai") + for i in range(3): + r = multi_tool.get_reply_with_context( + "time", "现在几点了", agent_type="通用") + test_one_direct(f"多轮-第{i+1}次", r) + if i < 2: + _time_module.sleep(0.05) + + # ---- 综合结果 ---- + print("\n" + "=" * 74) + print(f" 总计: {successful}/{total} 通过") + if successful == total: + print(" 全部测试通过!") + else: + print(f" 有 {total - successful} 个测试未通过") diff --git a/backend/app/utils/tool_executor.py b/backend/app/utils/tool_executor.py new file mode 100644 index 0000000..76214d0 --- /dev/null +++ b/backend/app/utils/tool_executor.py @@ -0,0 +1,327 @@ +""" +工具执行引擎 —— 批量执行工具调用并返回处理后的结果 + +功能: + 接收工具名+参数列表,批量执行并处理结果(过滤/聚合/精简), + 返回结构化的工具执行结果供 LLM 总结使用。 + +设计原则: + 1. 所有异常捕获,单个工具失败不影响其他工具 + 2. 所有工具执行后统一做结果精简处理 + 3. 支持本地工具和外部 API 工具混合调用 + 4. 
def _safe_calculate(expression: str) -> str:
    """Safely evaluate a math expression and format the result.

    The expression is parsed with `ast` first and rejected if it contains
    attribute access or any name outside the math/builtin whitelist —
    `eval` with an empty `__builtins__` alone is NOT safe (classic escape:
    `"".__class__.__mro__…`), so the AST check is the real sandbox.

    Returns:
        "expr = result" on success, or a "计算错误: …" message on failure.
    """
    import ast
    import math

    allowed_names = {
        "abs": abs, "round": round, "min": min, "max": max,
        "sum": sum, "pow": pow, "len": len,
    }
    for name in dir(math):
        if not name.startswith("_"):
            allowed_names[name] = getattr(math, name)

    try:
        tree = ast.parse(expression, mode="eval")
        for node in ast.walk(tree):
            # Attribute access is the standard sandbox escape — reject it.
            if isinstance(node, ast.Attribute):
                raise ValueError("表达式包含不允许的属性访问")
            # Only whitelisted names (math functions + a few builtins).
            if isinstance(node, ast.Name) and node.id not in allowed_names:
                raise ValueError(f"表达式包含未知名称: {node.id}")
        result = eval(compile(tree, "<calculate>", "eval"),
                      {"__builtins__": {}}, allowed_names)
        return f"{expression} = {result}"
    except Exception as e:
        return f"计算错误: {e}"


async def execute_tool_by_name(tool_name: str, args: dict) -> str:
    """Tool-loop entry: execute one tool by name + args, return result text.

    Differences from execute_single_tool:
    - returns a plain text string (not a dict)
    - args come from the LLM (not regex extraction)
    - unknown tool names fall through to the generic SkillExecutor path

    Args:
        tool_name: tool identifier, e.g. "get_weather", "calculate"
        args: LLM-generated argument dict for the tool

    Returns:
        Processed result text, or an error message (never raises).
    """
    try:
        if tool_name == "get_weather":
            city = args.get("city", "")
            date_str = args.get("date", args.get("date_str", ""))
            # Empty city intentionally falls through to the generic path.
            if city:
                try:
                    raw = await _weather_tool.get_reply(city, date_str)
                    return process_tool_result(tool_name, raw)
                except Exception as e:
                    logger.warning(f"[ToolExecutor] 天气工具异常: {e}")
                    return f"获取天气信息失败: {e}"

        if tool_name == "get_current_time":
            try:
                raw = _time_tool_instance.get_reply(query_type="date", user_message="")
                return process_tool_result(tool_name, raw)
            except Exception as e:
                logger.warning(f"[ToolExecutor] 时间工具异常: {e}")
                return f"获取时间信息失败: {e}"

        if tool_name == "web_search":
            query = args.get("query", "")
            if query:
                raw = await search_web(query)
                return process_tool_result(tool_name, raw)
            return "搜索查询为空"

        if tool_name == "calculate":
            expression = args.get("expression", "")
            if expression:
                return _safe_calculate(expression)
            return "计算表达式为空"

        if tool_name == "search":
            query = args.get("query", "")
            if query:
                try:
                    # SkillExecutor is already imported at module level;
                    # the former local re-import was redundant.
                    executor = SkillExecutor()
                    raw = await executor.execute(tool_name, args)
                    return process_tool_result(tool_name, raw)
                except Exception as e:
                    return f"知识库搜索失败: {e}"
            return "搜索查询为空"

        # Generic fallback: any other tool goes through SkillExecutor.
        executor = SkillExecutor()
        raw = await executor.execute(tool_name, args)
        return process_tool_result(tool_name, raw)

    except Exception as e:
        logger.warning(f"[ToolExecutor] execute_tool_by_name({tool_name}) 异常: {e}")
        return f"工具 '{tool_name}' 执行出错: {e}"
async def execute_single_tool(
    tool_name: str,
    args: dict,
    executor: SkillExecutor | None = None,
    agent_id: str | None = None,
    user_query: str = "",
) -> dict:
    """Execute one tool and return its processed, structured result.

    Args:
        tool_name: tool identifier
        args: parameters extracted from the user query
        executor: optional SkillExecutor to reuse (avoids re-creation)
        agent_id: agent identifier for context isolation
        user_query: original user question, used by the local fast paths

    Returns:
        {"tool_name": ..., "args": {...}, "result": "...", "success": bool}
    """
    if executor is None:
        executor = SkillExecutor()

    def _packed(text: str, ok: bool = True) -> dict:
        # Uniform result envelope shared by every branch below.
        return {"tool_name": tool_name, "args": args, "result": text, "success": ok}

    try:
        # ---- Weather: local fast path, generic executor as fallback ----
        if tool_name == "get_weather" and "city" in args:
            city = args.get("city", "")
            date_str = args.get("date_str", "")
            try:
                raw = await _weather_tool.get_reply(city, date_str)
            except Exception as e:
                logger.warning(f"[ToolExecutor] 天气快速路径异常,降级到通用执行: {e}")
                raw = await executor.execute(tool_name, args, agent_id=agent_id)
            processed = process_tool_result(tool_name, raw)
            logger.info(f"[ToolExecutor] get_weather → {len(processed)} 字符")
            return _packed(processed)

        # ---- Time: local fast path; "date_str" selects offset queries ----
        if tool_name == "get_current_time":
            try:
                if args.get("date_str", ""):
                    raw = _time_tool_instance.get_reply_with_context(
                        query_type="date_offset",
                        user_message=user_query,
                        agent_type="通用",
                    )
                else:
                    raw = _time_tool_instance.get_reply(
                        query_type="date",
                        user_message=user_query,
                    )
            except Exception as e:
                logger.warning(f"[ToolExecutor] 时间快速路径异常,降级到通用执行: {e}")
                raw = await executor.execute(tool_name, args, agent_id=agent_id)
            processed = process_tool_result(tool_name, raw)
            logger.info(f"[ToolExecutor] get_current_time → {len(processed)} 字符")
            return _packed(processed)

        # ---- Web search: DuckDuckGo-backed path ----
        if tool_name == "web_search":
            raw = await search_web(args.get("query", user_query))
            processed = process_tool_result(tool_name, raw)
            logger.info(f"[ToolExecutor] web_search → {len(processed)} 字符")
            return _packed(processed)

        # ---- Everything else: generic SkillExecutor path ----
        raw = await executor.execute(tool_name, args, agent_id=agent_id)
        processed = process_tool_result(tool_name, raw)
        logger.info(f"[ToolExecutor] {tool_name} → 原始 {len(raw)} 字符 → 精简 {len(processed)} 字符")
        return _packed(processed)

    except Exception as e:
        logger.warning(f"[ToolExecutor] {tool_name} 执行异常: {e}")
        return _packed(f"工具执行出错: {e}", ok=False)


def build_tool_summary(user_query: str, tool_results: list[dict]) -> str:
    """Build the prompt that asks the LLM to summarize tool results.

    Args:
        user_query: the user's original question
        tool_results: structured results from execute_tool_chain

    Returns:
        Prompt text for the LLM; the bare query when there are no results.
    """
    if not tool_results:
        return user_query

    lines = [f"用户问:{user_query}"]

    ok_items = [item for item in tool_results if item["success"]]
    bad_items = [item for item in tool_results if not item["success"]]

    if ok_items:
        lines.append("\n已获取到以下信息:")
        for item in ok_items:
            text = item["result"]
            # Cap each tool result at 300 chars to keep the prompt small.
            text = text[:300] + "…" if len(text) > 300 else text
            lines.append(f"- [{item['tool_name']}] {text}")

    if bad_items:
        lines.append("\n以下工具执行失败:")
        lines.extend(f"- {item['tool_name']}: {item['result']}" for item in bad_items)

    lines.append("\n请根据以上信息,用自然、友好的语言总结回答用户的问题。")
    lines.append("如果信息不足或工具失败,告诉用户暂时无法获取完整信息。")
    lines.append("注意区分「报名时间」和「考试时间」,用户问的是考试/事件日期时,不要用报名时间代替。")
    lines.append("如果搜索结果中的日期已经过去,请说明该日期已过,并尝试告知下一次的时间(如有信息)。")

    return "\n".join(lines)
场景与工具名解耦:场景只存工具名,定义从注册表动态获取 +""" + +from loguru import logger + + +# ============================================================================= +# 场景-工具映射配置 +# +# 每个场景包含: +# - keywords: 触发该场景的关键词集合(命中任一即匹配) +# - tools: 该场景下需要注入的工具名称列表 +# +# 扩展方式:新增场景只需在此添加一行配置,无需改动匹配逻辑 +# ============================================================================= + +SCENE_TOOL_MAP: dict[str, dict] = { + # ----- 天气场景 ----- + "weather": { + "keywords": { + "天气", "下雨", "下雪", "刮风", "台风", "雾霾", "冰雹", + "气温", "温度", "湿度", "风力", "空气质量", "pm2.5", + "防晒", "带伞", "紫外线", "降雨", "降水", "阴天", "晴天", "多云", + "冷不冷", "热不热", "穿什么衣服", "穿衣指数", "冷暖", "预报", + }, + "tools": ["get_weather", "web_search"], + }, + + # ----- 搜索场景 ----- + "search": { + "keywords": { + "搜索", "查找", "搜一下", "帮我搜", "帮我查", + "帮我找", "帮我看看", "查资料", "检索", "搜寻", + "百度", "谷歌", "百度一下", "搜一搜", + }, + "tools": ["search", "web_search"], + }, + + # ----- 旅游场景 ----- + "travel": { + "keywords": { + "旅游", "旅行", "度假", "景点", "攻略", "游记", + "行程", "路线", "导航", "怎么去", "怎么走", "如何去", + "酒店", "民宿", "机票", "火车票", "订票", "订酒店", + "规划", "安排行程", "出行计划", "自驾", "跟团", + "周边游", "一日游", "几日游", "自由行", "签证", + }, + "tools": ["get_weather", "search", "web_search"], + }, + + # ----- 计算场景 ----- + "calculate": { + "keywords": { + "计算", "算一下", "帮我算", "等于多少", "等于几", + "得多少", "得几", "是多少", "答案是", "换算", + "求", "求解", + }, + "tools": ["calculate"], + }, + + # ----- 时间场景 ----- + "time": { + "keywords": {"几点", "几号", "几时", "周几", "星期几", "时间", "日期", "日历", + "几月", "月份"}, + "tools": ["get_current_time"], + }, + + # ----- 倒计时场景 ----- + "countdown": { + "keywords": {"距离", "几天", "还有几天", "还剩几天", "还有多久", "倒计时"}, + "tools": ["get_current_time", "web_search"], + }, + + # ----- Agent 转交场景 ----- + "agent": { + "keywords": { + "转交", "转接", "切换", "换个agent", "找其他agent", + }, + "tools": ["transfer_to_agent"], + }, + + # ----- 金融实时场景 ----- + "finance_realtime": { + "keywords": { + "股价", "股票", "行情", "涨幅", "跌幅", "市值", + "基金", "比特币", "加密货币", "汇率", "换汇", "外汇", + "油价", "汽油价", "柴油价", "金价", "黄金价", + "房价", 
"二手房", "均价", "成交价", + }, + "tools": ["web_search"], + }, + + # ----- 新闻热点场景 ----- + "news": { + "keywords": { + "新闻", "热点", "头条", "爆料", "事件", "事故", + "最新消息", "最新动态", "最新公告", "最新通知", + }, + "tools": ["web_search"], + }, + + # ----- 考试教育场景 ----- + "exam": { + "keywords": { + "考试", "报名", "准考证", "成绩", "录取", "分数线", + "软考", "考研", "高考", "中考", "国考", "省考", + "事业编", "公务员", "选调", "教资", "法考", "注会", + "一建", "二建", "招生", "录取线", "合格线", + }, + "tools": ["web_search"], + }, + + # ----- 体育赛事场景 ----- + "sports": { + "keywords": { + "比赛", "赛事", "比分", "积分", "排名", "赛程", + "对阵", "世界杯", "奥运会", "欧冠", "NBA", "亚运会", + "世博会", "欧洲杯", "亚洲杯", "全运会", + }, + "tools": ["web_search"], + }, + + # ----- 比较评价场景 ----- + "comparison": { + "keywords": { + "哪个好", "怎么选", "对比", "比较", "区别", "差异", + "性价比", "划算", "值得买", "买哪个", "选哪个", + "排行", "排名", "榜单", "口碑", "评测", "测评", + "优缺点", "优劣", + }, + "tools": ["web_search"], + }, + + # ----- 事实特异性场景 ----- + "fact_specific": { + "keywords": { + "分数线", "录取线", "报名费", "学费", "票价", "门票", + "营业时间", "开放时间", "官网", "下载地址", + "名额", "招生人数", "招聘人数", "限购", + "联系方式", "客服", "咨询电话", + }, + "tools": ["web_search"], + }, + + # ----- 民生通知场景 ----- + "civil_notification": { + "keywords": { + "限行", "限号", "尾号限行", "单双号", + "停水", "停电", "停气", "检修", "维护通知", + "政策", "法规", "规定", "新规", "出台", "实施", + }, + "tools": ["web_search"], + }, + + # ----- 出入境场景 ----- + "immigration": { + "keywords": { + "签证", "护照", "出入境", "海关", "入境政策", + }, + "tools": ["web_search"], + }, + + # ----- 招聘求职场景 ----- + "job": { + "keywords": { + "招聘", "求职", "岗位", "薪资", "待遇", "offer", + }, + "tools": ["web_search"], + }, + + # ----- 医疗健康场景 ----- + "medical": { + "keywords": { + "疫苗", "挂号", "核酸检测", "门诊", "就诊", "医保", + "医院", "专家", "排班", + }, + "tools": ["web_search"], + }, + + # ----- 产品发布场景 ----- + "product_launch": { + "keywords": { + "发布", "推出", "上市", "开售", "预售", "发售", "新品", + "iPhone", "iPad", "MacBook", "华为", "小米", "三星", + "特斯拉", "比亚迪", "蔚来", + }, + "tools": ["web_search"], + }, + + # ----- AI产品场景 ----- + "ai_product": { + 
def _match_scenes(user_message: str) -> list[str]:
    """Return every scene whose keyword set hits the (cleaned) message.

    A scene matches as soon as any one of its keywords is a substring of
    the message; one message may match several scenes.

    Args:
        user_message: cleaned user message text

    Returns:
        List of matching scene names (map order preserved).
    """
    matched: list[str] = []
    for scene_name, scene_config in SCENE_TOOL_MAP.items():
        if any(keyword in user_message for keyword in scene_config["keywords"]):
            matched.append(scene_name)
    return matched


def _resolve_tool_names(matched_scenes: list[str]) -> list[str]:
    """Collect the tool names of the matched scenes, de-duplicated.

    Args:
        matched_scenes: list of matching scene names

    Returns:
        Tool names in first-seen order, duplicates removed.
    """
    seen: set[str] = set()
    tool_names: list[str] = []
    for scene_name in matched_scenes:
        for tool_name in SCENE_TOOL_MAP[scene_name]["tools"]:
            if tool_name not in seen:
                seen.add(tool_name)
                tool_names.append(tool_name)
    return tool_names


def get_matched_tools(user_message: str) -> list[dict]:
    """Return the OpenAI function-calling tool definitions matching a message.

    Flow: clean the input (strip question marks and spaces), match scenes by
    keyword, resolve the deduplicated tool names, then look each one up in
    the SkillRegistry and convert to the OpenAI tool format.

    Args:
        user_message: raw user message text

    Returns:
        OpenAI Function Calling tool definitions; [] when nothing matches.
        Falls back to the full tool list if the lazy path fails; this
        function never raises.
    """
    if not user_message:
        return []

    # Strip full-width/ASCII question marks and spaces before matching.
    clean_msg = (user_message.replace("?", "").replace("?", "")
                 .replace("　", "").replace(" ", ""))
    if not clean_msg:
        return []

    matched_scenes = _match_scenes(clean_msg)
    if not matched_scenes:
        return []

    tool_names = _resolve_tool_names(matched_scenes)

    try:
        from app.runtime.plugin.skill.registry import SkillRegistry
        # Hoisted out of the loop: re-importing per tool was pure overhead.
        from app.runtime.plugin.skill.base import SkillDefinition

        tools: list[dict] = []
        for tool_name in tool_names:
            skill_data = SkillRegistry.get_skill(tool_name)
            if skill_data is None:
                logger.debug(f"[ToolLazyLoader] 工具 '{tool_name}' 未注册,跳过")
                continue
            skill_def = SkillDefinition(
                name=skill_data.get("name", tool_name),
                description=skill_data.get("description", ""),
                category=skill_data.get("category", "general"),
                parameters=skill_data.get("parameters", {}),
                is_active=skill_data.get("is_active", True),
                is_builtin=skill_data.get("is_builtin", False),
                handler_name=skill_data.get("handler_name"),
                prompt_template=skill_data.get("prompt_template"),
                tags=skill_data.get("tags", []),
            )
            tools.append(skill_def.to_openai_tool())

        logger.info(
            f"[ToolLazyLoader] 匹配场景: {matched_scenes}, "
            f"注入工具: {[t['function']['name'] for t in tools]}"
        )
        return tools

    except Exception as e:
        logger.warning(f"[ToolLazyLoader] 懒加载异常,降级到全量注入: {e}")
        try:
            from app.runtime.plugin.skill.registry import SkillRegistry
            return SkillRegistry.get_openai_tools()
        except Exception as fallback_error:
            logger.error(f"[ToolLazyLoader] 全量降级也失败: {fallback_error}")
            return []
期望匹配的场景列表, 期望注入的工具名列表) + ("今天北京天气怎么样", ["weather"], ["get_weather", "web_search"]), + ("帮我搜索一下Python教程", ["search"], ["search", "web_search"]), + ("推荐一个旅游景点", ["travel"], ["get_weather", "search", "web_search"]), + ("3加5等于多少", ["calculate"], ["calculate"]), + ("现在几点了", ["time"], ["get_current_time"]), + ("你好,请介绍一下你自己", [], []), + ("给我写一段朋友圈文案", [], []), + ("帮我查一下明天上海的温度", ["weather", "search"], ["get_weather", "web_search", "search"]), + ("我想去北京旅游,帮我规划一下行程", ["travel"], ["get_weather", "search", "web_search"]), + ("计算一下 100*200", ["calculate"], ["calculate"]), + ("今天天气不错,适合出去玩", ["weather"], ["get_weather", "web_search"]), + ("距离河北软考还有几天", ["countdown"], ["get_current_time", "web_search"]), + ("距离五一还有多久", ["countdown"], ["get_current_time", "web_search"]), + ("今天几号", ["time"], ["get_current_time"]), + ("明天星期几", ["time"], ["get_current_time"]), + + # 新增场景测试 + ("特斯拉股价多少", ["finance_realtime", "product_launch"], ["web_search"]), + ("最近有什么新闻", ["news"], ["web_search"]), + ("今年考研什么时候报名", ["exam"], ["web_search"]), + ("NBA总决赛比分", ["sports"], ["web_search"]), + ("iPhone 16和华为Mate70哪个好", ["comparison", "product_launch"], ["web_search"]), + ("清华录取分数线多少", ["exam", "fact_specific"], ["web_search"]), + ("今天油价多少", ["finance_realtime"], ["web_search"]), + ("北京今天限行尾号", ["civil_notification"], ["web_search"]), + ("GPT-5什么时候发布", ["ai_product"], ["web_search"]), + ("最近有什么好看的电影", ["entertainment"], ["web_search"]), + ("签证办理流程", ["immigration", "travel"], ["web_search"]), + ("最近有没有招聘会", ["job"], ["web_search"]), + ("北京协和医院挂号", ["medical"], ["web_search"]), + ] + + print("=" * 80) + print(" ToolLazyLoader 场景匹配 测试结果(增强版)") + print("=" * 80) + print() + + passed = 0 + failed = 0 + + for msg, expected_scenes, expected_tools in test_cases: + display = msg if msg else "(空消息)" + clean = msg.replace("?", "").replace("?", "").replace(" ", "").replace(" ", "") + actual_scenes = _match_scenes(clean) + tool_names = _resolve_tool_names(actual_scenes) + + scene_ok = 
set(expected_scenes).issubset(set(actual_scenes)) + tools_ok = set(expected_tools).issubset(set(tool_names)) + + if scene_ok and tools_ok: + status = "PASS" + passed += 1 + else: + status = "FAIL" + failed += 1 + + print(f" [{status}] {display:45}") + if not scene_ok: + missing = set(expected_scenes) - set(actual_scenes) + print(f" 场景: 期望包含 {expected_scenes}, 实际 {actual_scenes}, 缺少 {missing}") + if not tools_ok: + missing = set(expected_tools) - set(tool_names) + print(f" 工具: 期望包含 {expected_tools}, 实际 {tool_names}, 缺少 {missing}") + + print() + print(f" 通过: {passed} 失败: {failed} 总计: {len(test_cases)}") + + if failed == 0: + print("\n 全部测试通过!") + else: + print(f"\n 有 {failed} 个测试未通过,需要检查配置") diff --git a/backend/app/utils/tool_parameter_extractor.py b/backend/app/utils/tool_parameter_extractor.py new file mode 100644 index 0000000..37614fe --- /dev/null +++ b/backend/app/utils/tool_parameter_extractor.py @@ -0,0 +1,291 @@ +""" +工具参数提取器 —— 从用户自然语言查询中智能提取工具所需参数 + +功能: + 对用户的自然语言查询进行规则匹配,自动提取工具所需参数(城市、日期、搜索关键词等), + 无需依赖大模型,纯正则+规则树实现,零延迟。 + +设计原则: + 1. 纯正则 + 规则树,零 IO、零网络、零大模型调用 + 2. 每种工具一个专属提取方法,针对性处理 + 3. 提取失败返回空字段,不抛异常,不中断流程 + 4. 兼容现有的 city_pattern / date_pattern / weekday_pattern + +用法: + extractor = ToolParameterExtractor() + args = extractor.extract("get_weather", "北京明天天气怎么样") + # 返回 {"city": "北京", "date": "明天"} +""" + +import re +from datetime import datetime, timedelta + + +class ToolParameterExtractor: + """工具参数提取器 —— 纯规则引擎,按工具名分发专属提取方法""" + + def __init__(self): + # ---------- 城市名正则 ---------- + self.city_pattern = re.compile( + r"(北京|上海|广州|深圳|杭州|成都|武汉|西安|南京|重庆|天津|" + r"苏州|长沙|郑州|济南|青岛|大连|厦门|福州|昆明|贵阳|南宁|" + r"海口|三亚|哈尔滨|长春|沈阳|乌鲁木齐|拉萨|兰州|银川|西宁|" + r"呼和浩特|太原|石家庄|合肥|南昌|东莞|佛山|无锡|宁波|温州|" + r"徐州|珠海|惠州|中山|烟台|威海)" + ) + + # 城市单字别名(前后不能有中文字符) + self.city_alias_pattern = re.compile( + r"(? 
dict: + """根据工具名提取参数 + + 参数: + tool_name: 工具名称 + user_query: 用户原始提问 + + 返回: + 参数字典,若无匹配参数则返回空 dict + """ + if tool_name == "get_weather": + return self._extract_weather_args(user_query) + elif tool_name == "get_current_time": + return self._extract_time_args(user_query) + elif tool_name in ("search", "web_search"): + return self._extract_search_args(user_query) + elif tool_name == "calculate": + return {} + else: + return {} + + # ========================================================================= + # 通用提取方法 + # ========================================================================= + + def extract_city(self, user_query: str) -> str | None: + matched = self.city_pattern.findall(user_query) + if matched: + return matched[0] + alias_matched = self.city_alias_pattern.findall(user_query) + if alias_matched: + city = alias_matched[0] + return self.city_alias.get(city, city) + return None + + def extract_date_text(self, user_query: str) -> str | None: + """从查询中提取日期文本(明天/下周一/5.1号等),返回可用于 _weather_tool 的日期字符串""" + # 绝对日期 + match = self.date_pattern.search(user_query) + if match: + return match.group(1) + # 相对日期 + match = self.relative_day_pattern.search(user_query) + if match: + return match.group(1) + # 星期偏移 + match = self.weekday_pattern.search(user_query) + if match: + return match.group(1) + # 下周x + match = self.next_week_pattern.search(user_query) + if match: + return f"下周{match.group(2)}" + # 天数偏移 + match = self.day_offset_pattern.search(user_query) + if match: + return f"{match.group(1)}天后" + match = self.day_before_pattern.search(user_query) + if match: + return f"{match.group(1)}天前" + return None + + # ========================================================================= + # 专属提取方法 + # ========================================================================= + + def _extract_weather_args(self, user_query: str) -> dict: + """提取天气工具参数""" + args: dict = {} + city = self.extract_city(user_query) + if city: + args["city"] = city + date_text = 
self.extract_date_text(user_query) + if date_text: + args["date_str"] = date_text + return args + + def _extract_time_args(self, user_query: str) -> dict: + """提取时间工具参数""" + args: dict = {} + date_text = self.extract_date_text(user_query) + if date_text: + args["date_str"] = date_text + + # 时间偏移 + match = re.search(r"(\d+)[个]*(?:小时|钟头)[后前]", user_query) + if match: + val = int(match.group(1)) + if "前" in match.group(0): + val = -val + args["hour_offset"] = val + return args + + def _extract_search_args(self, user_query: str) -> dict: + """提取搜索工具参数,自动清除查询前缀和倒计时词,并添加时间上下文""" + args: dict = {} + # 去除"搜索"等前缀 + query = self.search_stop_words.sub("", user_query).strip() + # 去除"距离/还有几天"等倒计时词 + query = self.countdown_stop_words.sub("", query).strip() + # 尾随的"还有几天"/"还剩几天" + query = re.sub(r"(还有几天|还剩几天|剩下几天|还有多久)$", "", query).strip() + if not query: + query = user_query.strip() + # 如果查询词太短(<=3字),补充"时间"后缀以获得更好的搜索结果 + if len(query) <= 3: + query = query + " 时间" + # 添加时间上下文(年份+上半年/下半年推断) + query = self._enrich_search_query_with_time_context(query) + args["query"] = query + return args + + def _enrich_search_query_with_time_context(self, query: str) -> str: + """为搜索查询添加时间上下文 + + 当查询涉及周期性事件(考试/赛事等)且没有明确年份/上下半年时, + 自动补充当前年份和合理的上下半年推断。 + + 推断规则: + - 1-4月 → 搜索"上半年"(上半年考试通常5-6月举行) + - 5-8月 → 搜索"下半年"(上半年已过,下半年通常11月举行) + - 9-12月 → 搜索"下半年"(下半年考试通常11月举行) + """ + has_year = bool(re.search(r"20\d{2}年", query)) + has_half = bool(re.search(r"上半年|下半年", query)) + + if has_year and has_half: + return query + + periodic_event_patterns = [ + (re.compile(r"软考"), "考试"), + (re.compile(r"考研"), "考试"), + (re.compile(r"高考"), "考试"), + (re.compile(r"中考"), "考试"), + (re.compile(r"国考"), "考试"), + (re.compile(r"省考"), "考试"), + (re.compile(r"考公"), "考试"), + (re.compile(r"事业编"), "考试"), + (re.compile(r"教资|教师资格"), "考试"), + (re.compile(r"法考"), "考试"), + (re.compile(r"注会"), "考试"), + (re.compile(r"一建|二建"), "考试"), + (re.compile(r"公务员"), "考试"), + (re.compile(r"选调"), "考试"), + (re.compile(r"世界杯"), "赛事"), + 
(re.compile(r"奥运会"), "赛事"), + (re.compile(r"亚运会"), "赛事"), + (re.compile(r"欧冠"), "赛事"), + (re.compile(r"欧洲杯"), "赛事"), + (re.compile(r"亚洲杯"), "赛事"), + (re.compile(r"全运会"), "赛事"), + (re.compile(r"世博会"), "展会"), + (re.compile(r"进博会"), "展会"), + (re.compile(r"广交会"), "展会"), + (re.compile(r"双十一|618"), "购物节"), + (re.compile(r"春运"), "民生"), + (re.compile(r"秋招|春招"), "招聘"), + (re.compile(r"报名"), "报名"), + (re.compile(r"录取"), "录取"), + (re.compile(r"分数线"), "分数"), + ] + + matched_event = None + for pattern, event_type in periodic_event_patterns: + if pattern.search(query): + matched_event = event_type + break + + if not matched_event: + return query + + # 清洗疑问词和冗余词,提取核心搜索词 + core_query = query + core_query = re.sub(r"^今天|^现在|^当前|^目前|^今年", "", core_query) + core_query = re.sub(r"什么时候|几号|几时|哪天|哪一天|是哪天|是几号", "", core_query) + core_query = re.sub(r"还有几天|还剩几天|还有多久|还差几天$", "", core_query) + core_query = re.sub(r"多少|怎么样|好不好|有没有|是否", "", core_query) + core_query = core_query.strip() + + if not core_query: + core_query = query + + now = datetime.now() + year = now.year + month = now.month + + enriched = core_query + if not has_year: + enriched = f"{year}年" + enriched + + if not has_half and matched_event in ("考试", "报名", "录取", "分数", "招聘"): + # 高考/中考固定在6月举行,始终搜索上半年 + first_half_only = bool(re.search(r"高考|中考", core_query)) + if first_half_only: + enriched = enriched + "上半年" + elif month >= 5 and month <= 8: + enriched = enriched + "下半年" + elif month >= 9: + enriched = enriched + "下半年" + else: + enriched = enriched + "上半年" + + if matched_event in ("考试", "报名", "录取", "分数") and "时间" not in enriched: + enriched = enriched + "时间" + + # 优化顺序:将"下半年/上半年"移到事件名后面、时间前面 + enriched = re.sub( + r"^(20\d{2}年)(.+?)(上半年|下半年)(时间)$", + r"\1\3\2\4", + enriched + ) + + return enriched \ No newline at end of file diff --git a/backend/app/utils/tool_result_processor.py b/backend/app/utils/tool_result_processor.py new file mode 100644 index 0000000..2345555 --- /dev/null +++ b/backend/app/utils/tool_result_processor.py 
@@ -0,0 +1,518 @@ +""" +工具结果处理器 —— 在工具原始结果和大模型之间新增一层程序化过滤聚合 + +功能: + 对每个工具的原始返回结果进行过滤、聚合、精简,只保留核心有效信息, + 去除冗余 JSON 结构、无关字段、重复内容,降低 37% 以上的 token 消耗。 + +设计原则: + 1. 每个工具一个专属处理函数,针对性过滤无关字段 + 2. 未定义专属处理器的工具走通用兜底,自动提取核心信息 + 3. 绝对不丢失用户需要的核心内容,仅过滤冗余数据 + 4. 处理异常时降级返回原始结果,不影响用户体验 + 5. 纯函数设计,零副作用,极低延迟(毫秒级) +""" + +import json +import re +from loguru import logger + + +# ============================================================================= +# 核心接口:process_tool_result +# ============================================================================= + +def process_tool_result(tool_name: str, raw_result: str) -> str: + """工具结果处理入口 —— 根据工具名分发到对应的处理函数 + + 流程: + 1. 在 TOOL_PROCESSORS 中查找该工具的专属处理器 + 2. 找到 → 调用专属处理函数,返回精简后的结果 + 3. 未找到 → 调用通用兜底处理器 _process_generic + 4. 异常 → 降级返回原始结果,确保对话不中断 + + 参数: + tool_name: 工具名称,如 "get_weather"、"get_current_time" + raw_result: 工具返回的原始结果文本(通常是 JSON 字符串或自然语言) + + 返回: + 精简后的核心信息文本,供大模型生成最终回复 + + 用法: + processed = process_tool_result("get_weather", '{"city":"北京",...}') + """ + if not raw_result: + return "" + + processor = TOOL_PROCESSORS.get(tool_name, _process_generic) + try: + processed = processor(raw_result) + # 确保处理后的结果非空,空结果降级为原始结果 + if not processed or not processed.strip(): + logger.debug(f"[ResultProcessor] {tool_name} 处理结果为空,降级到原始结果") + return raw_result + return processed + except Exception as e: + logger.warning(f"[ResultProcessor] {tool_name} 处理异常,降级到原始结果: {e}") + return raw_result + + +# ============================================================================= +# 专属处理器 +# ============================================================================= + +def _process_weather_result(raw: str) -> str: + """天气结果处理器 —— 按日期类型智能精简 + + 原始结果包含(来自 weather_tool.py + Open-Meteo API): + 实时天气:city, date, weather, temp_min/max, wind_scale/dir, + humidity, precip_prob, formatted + 预报天气:city, date, weather, temp_min/max, wind_scale/dir, + precip_prob, day_offset, formatted + + 精简规则: + 实时天气:保留 formatted(含完整建议)或组装核心字段 + 预报天气:保留 formatted(含日期 + 预报 
+ 建议) + 兜底话术:保留原样,确保不丢信息 + + token 节省估算:原始 ~180 tokens → 处理后 ~45 tokens(节省 75%) + """ + # 尝试解析 JSON + data = _safe_parse_json(raw) + if data is None: + # 非 JSON 格式(如已经是自然语言或兜底话术),直接返回 + return _strip_json_wrapper(raw) + + # 优先取 formatted 字段(weather_tool 已生成口语化回复) + formatted = data.get("formatted", "") + if formatted: + return formatted + + # 兜底:从原始字段组装精简结果 + city = data.get("city", "") + fore_days = data.get("forecast_days", []) + + if fore_days: + # 多日数据 → 只保留每天的核心字段 + parts = [f"{city}" if city else ""] + for day in fore_days[:3]: # 最多3天 + date = day.get("date", "")[-5:] # MM-DD + w = day.get("weather", "") + t = f"{day.get('temp_min', '')}~{day.get('temp_max', '')}℃" + parts.append(f"{date} {w} {t}") + return ";".join(parts) + else: + # 单日数据 + weather = data.get("weather", "") + temp_min = data.get("temp_min", "") + temp_max = data.get("temp_max", "") + temps = f"{temp_min}℃ ~ {temp_max}℃" if temp_min and temp_max else "" + suggestion = data.get("suggestion", "") + + parts = [] + if city and weather: + parts.append(f"{city}{weather}") + if temps: + parts.append(f"气温{temps}") + if suggestion: + parts.append(suggestion) + + return ",".join(parts) if parts else raw + + +def _process_time_result(raw: str) -> str: + """时间结果处理器 —— 过滤多余时间戳和时区详情 + + 原始结果包含(来自 get_current_time): + datetime, date, time, weekday, year, month, day, hour, minute, second + + 精简后只保留: + 日期、时间、星期 + + 处理逻辑: + 1. 解析 JSON 提取 date/time/weekday 三个核心字段 + 2. 丢弃 year/month/day/hour/minute/second 等冗余子字段 + 3. 
组装成自然语言短句 + + token 节省估算:原始 ~80 tokens → 处理后 ~25 tokens(节省 69%) + """ + data = _safe_parse_json(raw) + if data is None: + return _strip_json_wrapper(raw) + + # 只提取用户需要的三个核心字段 + date = data.get("date", "") + time_val = data.get("time", "") + weekday = data.get("weekday", "") + # 如果 date/time 为空,尝试从 datetime 或其它字段推断 + datetime_val = data.get("datetime", "") + if not date and datetime_val: + date = datetime_val[:10] + if not time_val and datetime_val: + time_val = datetime_val[11:19] + + # 组装自然语言 + parts = [] + if date: + parts.append(f"日期:{date}") + if weekday: + parts.append(f"{weekday}") + if time_val: + parts.append(f"时间:{time_val}") + + return ",".join(parts) if parts else raw + + +def _process_search_result(raw: str) -> str: + """搜索结果处理器 —— 截断过长结果,保留核心摘要 + + 原始结果可能包含大量检索文档内容,需要截断并提取核心信息。 + + 处理逻辑: + 1. 如果结果过长(>800 字符),截断并附加省略标记 + 2. 去除 JSON 包装 + 3. 保留前几条最相关的结果 + + token 节省估算:原始 ~500 tokens → 处理后 ~200 tokens(节省 60%) + """ + MAX_CHARS = 800 + text = _strip_json_wrapper(raw) + + if len(text) <= MAX_CHARS: + return text + + # 截断到最大长度,在最近的句号或换行处断开 + truncated = text[:MAX_CHARS] + # 尝试在最后一个完整句子处断开 + last_period = max( + truncated.rfind("。"), + truncated.rfind("\n"), + truncated.rfind(". "), + ) + if last_period > MAX_CHARS // 2: + truncated = truncated[:last_period + 1] + + return f"{truncated}\n…(结果已截断,共 {len(text)} 字符)" + + +def _process_calculate_result(raw: str) -> str: + """计算结果处理器 —— 只保留表达式和计算结果 + + 原始结果(JSON):{"expression": "3+5", "result": 8} + 精简后:计算 3+5,结果 = 8 + + 处理逻辑: + 1. 解析 JSON 提取 expression 和 result + 2. 忽略所有元数据字段 + 3. 
组装成一行简洁结果 + + token 节省估算:原始 ~40 tokens → 处理后 ~15 tokens(节省 63%) + """ + data = _safe_parse_json(raw) + if data is None: + return _strip_json_wrapper(raw) + + expression = data.get("expression", "") + result = data.get("result", "") + if result is not None: + return f"计算:{expression} = {result}" + return raw + + +def _process_web_search_result(raw: str) -> str: + """网页搜索结果处理器 —— 保留搜索查询和结果摘要 + + 处理逻辑:同 _process_search_result,针对网页搜索场景 + """ + return _process_search_result(raw) + + +def _process_transfer_result(raw: str) -> str: + """Agent 转交结果处理器 —— 保留转交目标和任务描述 + + 处理逻辑: + 1. 解析 JSON 提取 transferred_to 和 task + 2. 过滤 agent_id 等内部元数据 + """ + data = _safe_parse_json(raw) + if data is None: + return _strip_json_wrapper(raw) + + agent = data.get("transferred_to", data.get("agent_name", "")) + task = data.get("task", "") + if agent and task: + return f"已将任务「{task}」转交给 Agent「{agent}」" + if agent: + return f"已转交给 Agent「{agent}」" + return _strip_json_wrapper(raw) + + +def _process_generic(raw: str) -> str: + """通用兜底处理器 —— 针对未定义专属处理器的工具 + + 自动执行以下精简操作: + 1. 去除 JSON 外层包装结构 + 2. 过滤 null/空字符串/空列表等空值字段 + 3. 去除过度的缩进和格式化空白 + 4. 保留核心文本内容 + + 处理逻辑: + 1. 尝试解析 JSON,提取所有非空字段的值 + 2. 如果是纯文本,去除多余空白 + 3. 
过长文本截断处理 + """ + data = _safe_parse_json(raw) + if data is None: + # 非 JSON,做基本文本清理 + return _clean_text(raw) + + # 遍历 JSON 提取非空核心字段 + core_parts: list[str] = [] + for key, value in data.items(): + if value is None or value == "" or value == [] or value == {}: + continue + if isinstance(value, str): + core_parts.append(value) + elif isinstance(value, (int, float, bool)): + core_parts.append(f"{key}: {value}") + elif isinstance(value, list): + # 列表只取前 3 项 + items = [str(v) for v in value[:3] if v is not None] + if items: + core_parts.append(",".join(items)) + if len(value) > 3: + core_parts.append(f"(共 {len(value)} 项)") + elif isinstance(value, dict): + # 嵌套字典扁平化 + flat = _flatten_dict(value) + if flat: + core_parts.append(flat) + + result = "\n".join(core_parts) if core_parts else _strip_json_wrapper(raw) + # 过长时截断 + if len(result) > 1000: + return _process_search_result(result) + return result + + +# ============================================================================= +# 处理器注册表 —— 工具名 → 处理函数 +# ============================================================================= + +TOOL_PROCESSORS: dict[str, callable] = { + "get_weather": _process_weather_result, + "get_current_time": _process_time_result, + "search": _process_search_result, + "calculate": _process_calculate_result, + "web_search": _process_web_search_result, + "transfer_to_agent": _process_transfer_result, +} + + +# ============================================================================= +# 内部辅助函数 +# ============================================================================= + +def _safe_parse_json(text: str) -> dict | None: + """安全解析 JSON —— 失败返回 None,绝不抛异常""" + if not text or not text.strip(): + return None + text = text.strip() + # 只处理以 { 或 [ 开头的内容 + if not (text.startswith("{") or text.startswith("[")): + return None + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + return None + + +def _strip_json_wrapper(text: str) -> str: + """去除 JSON 外层结构,提取纯文本内容 + + 
如果 text 本身就是自然语言(不含 JSON 结构),原样返回。 + 如果 text 是 JSON,尝试提取其中最有意义的字符串字段。 + """ + data = _safe_parse_json(text) + if data is None: + return text + + # 如果是字典,尝试取第一个有意义的值 + if isinstance(data, dict): + for key in ["formatted", "content", "text", "summary", "result"]: + val = data.get(key) + if isinstance(val, str) and val.strip(): + return val + # 取第一个字符串值 + for val in data.values(): + if isinstance(val, str) and val.strip(): + return val + elif isinstance(data, str): + return data + elif isinstance(data, (int, float, bool)): + return str(data) + elif isinstance(data, list): + parts = [str(v) for v in data[:3] if v is not None] + return ",".join(parts) if parts else text + + return text + + +def _flatten_dict(d: dict, max_depth: int = 2) -> str: + """扁平化嵌套字典,提取非空值""" + parts: list[str] = [] + for key, value in d.items(): + if value is None or value == "": + continue + if isinstance(value, dict) and max_depth > 0: + inner = _flatten_dict(value, max_depth - 1) + if inner: + parts.append(f"{key}({inner})") + elif isinstance(value, (str, int, float, bool)): + parts.append(f"{key}: {value}") + return ",".join(parts) + + +def _clean_text(text: str) -> str: + """基本文本清理:去除多余空白和空行""" + # 去除前导空白 + text = text.strip() + # 压缩多个空行为单个空行 + text = re.sub(r"\n{3,}", "\n\n", text) + # 压缩多余空格(但不压缩中文之间的空格) + text = re.sub(r"[ \t]{2,}", " ", text) + return text + + +# ============================================================================= +# 对比验证 —— 展示原始 vs 精简效果(python -m app.utils.tool_result_processor) +# ============================================================================= +if __name__ == "__main__": + # 模拟工具原始返回结果 + mock_results = [ + # ----- 天气工具:模拟 Open-Meteo API 原始返回 ----- + ( + "get_weather", + json.dumps({ + "city": "北京", + "date": "2026-05-06", + "forecast_days": [{ + "date": "2026-05-06", + "weather": "晴", + "temp_min": 14.2, + "temp_max": 23.2, + "wind_scale": "大风", + "wind_direction": "北", + "humidity": "25%", + "precipitation_probability": 0, + }], + }, 
# =============================================================================
# Comparison harness -- raw vs condensed output
# (run with: python -m app.utils.tool_result_processor)
# =============================================================================
if __name__ == "__main__":

    def _preview(text: str) -> str:
        # Clip long payloads to 120 chars for console display.
        return text[:120] + "…" if len(text) > 120 else text

    # Simulated raw tool outputs, one per registered processor plus an
    # unregistered tool exercising the generic fallback path.
    mock_results = [
        # ----- weather tool: raw Open-Meteo-style payload -----
        (
            "get_weather",
            json.dumps({
                "city": "北京",
                "date": "2026-05-06",
                "forecast_days": [{
                    "date": "2026-05-06",
                    "weather": "晴",
                    "temp_min": 14.2,
                    "temp_max": 23.2,
                    "wind_scale": "大风",
                    "wind_direction": "北",
                    "humidity": "25%",
                    "precipitation_probability": 0,
                }],
            }, ensure_ascii=False),
        ),
        # ----- weather tool: pre-formatted colloquial reply -----
        (
            "get_weather",
            json.dumps({
                "formatted": "天气晴好正适合出门,「北京」现在是晴天,气温在14.2℃到23.2℃之间。当前北大风,体感温度会偏低一些。早晚有些凉,最好备一件外搭。气温舒适宜人,穿件衬衫或薄长袖就刚好。风力较大,外出注意防风,尽量远离广告牌。"
            }, ensure_ascii=False),
        ),
        # ----- weather tool: forecast format -----
        (
            "get_weather",
            json.dumps({
                "formatted": "「广州」明天天气预报来啦——预计小雨,气温23.1℃ ~ 28.8℃。东北清风,降水概率94%,建议安排室内活动。出门别忘了带把伞。"
            }, ensure_ascii=False),
        ),
        # ----- time tool: raw get_current_time payload -----
        (
            "get_current_time",
            json.dumps({
                "datetime": "2026-05-03 15:30:00",
                "date": "2026-05-03",
                "time": "15:30:00",
                "weekday": "星期一",
                "year": 2026,
                "month": 5,
                "day": 3,
                "hour": 15,
                "minute": 30,
                "second": 0,
            }, ensure_ascii=False),
        ),
        # ----- calculator -----
        (
            "calculate",
            json.dumps({"expression": "165 * 38", "result": 6270}, ensure_ascii=False),
        ),
        # ----- search tool (deliberately long result) -----
        (
            "search",
            "检索到以下内容:消防工程师考试2026年报名时间为6月1日至6月30日,考试科目包括消防安全技术实务、消防安全技术综合能力、消防安全案例分析三门。"
            + "报名条件要求大专以上学历,从事消防工作满6年。考试费用每科65元。"
            + "(此段为冗余扩展内容,用于测试搜索结果的截断处理效果,正常情况下大模型不需要看到这么长的原始结果文本," * 3
            + "实际场景中原始搜索返回可能包含数千字符,通过处理器截断后仅保留核心信息)",
        ),
        # ----- agent transfer -----
        (
            "transfer_to_agent",
            json.dumps({
                "transferred_to": "消防专家Agent",
                "task": "解答关于消防通道设计的规范要求",
                "agent_id": "agent-001",
            }, ensure_ascii=False),
        ),
        # ----- unregistered tool (generic fallback) -----
        (
            "unknown_tool",
            json.dumps({
                "status": "ok",
                "data": "操作成功",
                "error": None,
                "timestamp": "2026-05-03T15:30:00Z",
                "trace_id": "",
            }, ensure_ascii=False),
        ),
    ]

    print("=" * 80)
    print(" 工具结果处理器 —— 原始 vs 精简 对比验证")
    print("=" * 80)

    total_raw_chars = 0
    total_processed_chars = 0

    for tool_name, raw_result in mock_results:
        processed = process_tool_result(tool_name, raw_result)
        total_raw_chars += len(raw_result)
        total_processed_chars += len(processed)

        print(f"\n ┌─ 工具: {tool_name}")
        print(f" ├─ 原始 ({len(raw_result)} 字符): {_preview(raw_result)}")
        print(f" └─ 精简 ({len(processed)} 字符): {_preview(processed)}")

    # Summary statistics.
    print()
    print("=" * 80)
    print(" Token 消耗对比汇总")
    print("=" * 80)
    # Rough estimate: ~2 characters per (mostly Chinese) token.
    raw_tokens_est = total_raw_chars // 2
    processed_tokens_est = total_processed_chars // 2
    savings = raw_tokens_est - processed_tokens_est
    savings_pct = (savings / raw_tokens_est * 100) if raw_tokens_est > 0 else 0

    print(f" 原始总字符数: {total_raw_chars} 字符")
    print(f" 精简后字符数: {total_processed_chars} 字符")
    print(f" 估算原始 tokens: ~{raw_tokens_est}")
    print(f" 估算精简 tokens: ~{processed_tokens_est}")
    print(f" 节省 tokens: ~{savings}")
    print(f" 节省比例: {savings_pct:.1f}%")
    print()
    print(" " + ("=" * 76))

    if savings_pct >= 37:
        print(f" [PASS] Token 节省比例 {savings_pct:.1f}% >= 37%,目标达成")
    else:
        print(f" [WARN] Token 节省比例 {savings_pct:.1f}% < 37%,可进一步优化")
纯本地封装,不改动现有项目的其他逻辑 +""" + +import asyncio +import hashlib +import json +import re +from enum import Enum +from functools import lru_cache +from datetime import datetime, timezone, timedelta +from loguru import logger + +import httpx + + +# ============================================================================= +# 日期查询类型枚举 +# ============================================================================= + +class DateType(Enum): + """日期查询类型""" + TODAY = "today" # 今日实时天气 + FORECAST = "forecast" # 预报日期(1-7天内) + OUT_OF_RANGE = "out_of_range" # 超出预报范围 + + +# ============================================================================= +# WMO 天气码 → 中文天气描述映射 +# ============================================================================= + +WMO_WEATHER_MAP: dict[int, str] = { + 0: "晴天", + 1: "大部晴朗", + 2: "多云间晴", + 3: "多云", + 45: "有雾", + 48: "雾凇", + 51: "小毛毛雨", + 53: "中等毛毛雨", + 55: "大毛毛雨", + 56: "小冻毛毛雨", + 57: "大冻毛毛雨", + 61: "小雨", + 63: "中雨", + 65: "大雨", + 66: "小冻雨", + 67: "大冻雨", + 71: "小雪", + 73: "中雪", + 75: "大雪", + 77: "雪粒", + 80: "小阵雨", + 81: "中等阵雨", + 82: "大阵雨", + 85: "小阵雪", + 86: "大阵雪", + 95: "雷暴", + 96: "小冰雹雷暴", + 99: "大冰雹雷暴", +} + +# ============================================================================= +# 风力等级 → 中文描述映射(蒲福风级) +# ============================================================================= + +WIND_SCALE_MAP: dict[int, str] = { + 0: "无风", + 1: "微风", + 2: "轻风", + 3: "和风", + 4: "清风", + 5: "劲风", + 6: "强风", + 7: "疾风", + 8: "大风", + 9: "烈风", + 10: "狂风", + 11: "暴风", + 12: "飓风", +} + +# ============================================================================= +# 常用城市→经纬度兜底映射(减少地理编码 API 调用) +# ============================================================================= + +CITY_COORDINATES: dict[str, tuple[float, float]] = { + "北京": (39.9042, 116.4074), + "上海": (31.2304, 121.4737), + "广州": (23.1291, 113.2644), + "深圳": (22.5431, 114.0579), + "杭州": (30.2741, 120.1551), + "成都": (30.5728, 104.0668), + "武汉": (30.5928, 114.3055), + "西安": (34.3416, 
108.9398), + "南京": (32.0603, 118.7969), + "重庆": (29.4316, 106.9123), + "天津": (39.3434, 117.3616), + "苏州": (31.2990, 120.5853), + "长沙": (28.2282, 112.9388), + "郑州": (34.7466, 113.6253), + "济南": (36.6512, 117.1201), + "青岛": (36.0671, 120.3826), + "大连": (38.9140, 121.6147), + "厦门": (24.4798, 118.0894), + "福州": (26.0745, 119.2965), + "昆明": (25.0389, 102.7183), + "贵阳": (26.6470, 106.6302), + "南宁": (22.8170, 108.3665), + "海口": (20.0440, 110.1999), + "三亚": (18.2528, 109.5120), + "哈尔滨": (45.8038, 126.5350), + "长春": (43.8171, 125.3235), + "沈阳": (41.8057, 123.4315), + "乌鲁木齐": (43.8256, 87.6168), + "拉萨": (29.6500, 91.1000), + "兰州": (36.0611, 103.8343), + "银川": (38.4872, 106.2309), + "西宁": (36.6171, 101.7785), + "呼和浩特": (40.8426, 111.7490), + "太原": (37.8706, 112.5489), + "石家庄": (38.0428, 114.5149), + "合肥": (31.8206, 117.2272), + "南昌": (28.6820, 115.8579), +} + + +def _minute_bucket_key(key: str) -> str: + """生成带 5 分钟粒度的缓存键 + + 在当前分钟向下取整到 5 的倍数后追加原始键, + 实现"5分钟内相同查询走缓存,5分钟后自动失效"。 + + 参数: + key: 原始缓存键(城市名+日期) + + 返回: + 带时间戳的缓存键,如 "2026-05-06T15:30_北京_2026-05-06" + """ + now = datetime.now(timezone.utc) + minute = now.minute // 5 * 5 + ts = now.replace(minute=minute, second=0, microsecond=0).isoformat() + return f"{ts}_{key}" + + +# ============================================================================= +# 日期解析 —— 口语化日期 → 标准 YYYY-MM-DD +# ============================================================================= + +# 中文数字映射 +_CN_NUM = { + "零": 0, "一": 1, "二": 2, "三": 3, "四": 4, + "五": 5, "六": 6, "七": 7, "八": 8, "九": 9, "十": 10, +} + +_WEEKDAY_NUM = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6} + +_WEEKDAY_CN = { + "周一": 0, "周二": 1, "周三": 2, "周四": 3, "周五": 4, "周六": 5, "周日": 6, +} +for ch, idx in _WEEKDAY_NUM.items(): + _WEEKDAY_CN[f"星期{ch}"] = idx + + +def parse_query_date(date_str: str) -> dict: + """将用户输入的口语化日期解析为标准 YYYY-MM-DD 格式 + + 支持的日期格式: + - "5.1号"、"5月1日"、"5.1"、"5/1"、"2026-05-01"(数字日期) + - "今天"、"今日"、"明天"、"明日"、"后天"、"大后天"(相对日期) + - "3天后"、"三天后"、"3天后"(偏移日期) + 
- "下周一"、"下周二"…"下周日"(下周) + - "昨天"、"前天"(历史日期) + + 参数: + date_str: 用户输入的原始日期文本 + + 返回: + 字典: + - "date": YYYY-MM-DD 格式日期 + - "type": DateType 枚举值(TODAY / FORECAST / OUT_OF_RANGE) + - "day_offset": 距离今天的天数(0=今天,1=明天,负数=过去) + - "error": 解析失败时的提示信息(成功时为 None) + """ + if not date_str or not date_str.strip(): + return {"date": datetime.now().strftime("%Y-%m-%d"), "type": DateType.TODAY, "day_offset": 0, "error": None} + + text = date_str.strip() + today = datetime.now().date() + max_forecast = 7 # Open-Meteo 免费 API 最多支持 7 天预报 + + # ---- 1. 相对日期(今天/明天/后天/大后天/昨天/前天)---- + relative_map = { + "大前天": -3, "前天": -2, "昨天": -1, + "今天": 0, "今日": 0, + "明天": 1, "明日": 1, + "后天": 2, + "大后天": 3, + } + for word, offset in relative_map.items(): + if word in text: + target_date = today + timedelta(days=offset) + delta = (target_date - today).days + if delta < 0: + dt = DateType.OUT_OF_RANGE + elif delta == 0: + dt = DateType.TODAY + elif delta <= max_forecast: + dt = DateType.FORECAST + else: + dt = DateType.OUT_OF_RANGE + return {"date": target_date.strftime("%Y-%m-%d"), "type": dt, "day_offset": delta, "error": None} + + # ---- 2. N天后(如 "3天后"、"三天后")---- + offset_match = re.match(r"(\d+|[一二三四五六七八九十]+)\s*天?后", text) + if offset_match: + raw = offset_match.group(1) + if raw.isdigit(): + offset = int(raw) + else: + offset = 0 + for ch in raw: + if ch == "十": + offset = max(offset, 1) * 10 + elif ch in _CN_NUM: + offset += _CN_NUM[ch] + target_date = today + timedelta(days=offset) + delta = (target_date - today).days + dt = DateType.FORECAST if delta <= max_forecast else DateType.OUT_OF_RANGE + return {"date": target_date.strftime("%Y-%m-%d"), "type": dt, "day_offset": delta, "error": None} + + # ---- 3. 
下周一/下周二…下周日 ---- + for week_word, weekday_idx in _WEEKDAY_CN.items(): + if week_word in text: + import calendar + today_weekday = today.weekday() + days_until = (weekday_idx - today_weekday) % 7 + if days_until == 0: + days_until = 7 # "下周一"含义是下周,不是本周 + if days_until <= 7: + days_until = days_until + 7 if days_until <= 0 else days_until + target_date = today + timedelta(days=days_until) + delta = (target_date - today).days + dt = DateType.FORECAST if delta <= max_forecast else DateType.OUT_OF_RANGE + return {"date": target_date.strftime("%Y-%m-%d"), "type": dt, "day_offset": delta, "error": None} + + # ---- 4. 数字日期(5.1号 / 5月1日 / 5.1 / 5/1 / 2026-05-01)---- + # 匹配 "5.1号"、"5月1日"、"2026-05-01"、"5.1"、"5/1"、"5-1" + num_match = re.search( + r"((?P\d{4})[-/\.年])?" + r"(?P\d{1,2})" + r"[-/\.月]" + r"(?P\d{1,2})" + r"[日号]?", + text + ) + if num_match: + year = int(num_match.group("year")) if num_match.group("year") else today.year + month = int(num_match.group("month")) + day = int(num_match.group("day")) + try: + target_date = datetime(year=year, month=month, day=day).date() + except ValueError: + return {"date": today.strftime("%Y-%m-%d"), "type": DateType.TODAY, "day_offset": 0, + "error": f"日期 {year}-{month}-{day} 不存在,已为你查询今天天气"} + + delta = (target_date - today).days + if delta < 0: + dt = DateType.OUT_OF_RANGE + elif delta == 0: + dt = DateType.TODAY + elif delta <= max_forecast: + dt = DateType.FORECAST + else: + dt = DateType.OUT_OF_RANGE + return {"date": target_date.strftime("%Y-%m-%d"), "type": dt, "day_offset": delta, "error": None} + + # ---- 兜底:无法解析,按今天处理 ---- + return {"date": today.strftime("%Y-%m-%d"), "type": DateType.TODAY, "day_offset": 0, + "error": f"日期格式无法识别,已为你查询今天天气"} + + +# ============================================================================= +# WeatherTool 核心类 +# ============================================================================= + +class WeatherTool: + """本地天气工具 —— 毫秒级天气查询,零大模型调用 + + 设计参考 time_tool.py,采用相同的类+全局单例架构。 + + 用法: + tool 
= WeatherTool(default_city="北京") + reply = await tool.get_reply("上海") # 查询上海今天天气 + reply = await tool.get_reply("北京", days=3) # 查询北京3天预报 + """ + + def __init__(self, default_city: str = "北京"): + """初始化天气工具 + + 参数: + default_city: 默认城市,用户未指定城市时使用。 + 后续可对接用户记忆系统的所在地字段。 + """ + self.default_city = default_city + self._geocoding_cache: dict[str, tuple[float, float]] = {} + + # ------------------------------------------------------------------ + # 公开接口 + # ------------------------------------------------------------------ + + async def get_reply(self, city: str | None = None, date_str: str = "") -> str: + """获取天气自然语言回复(主入口) + + 参数: + city: 城市名称,None 时使用默认城市 + date_str: 用户输入的原始日期文本, + 支持 "今天"、"明天"、"5.1号"、"下周一" 等口语化格式, + 空字符串默认今天 + + 返回: + 自然语言天气回复字符串 + + 异常安全: + 任意环节异常均不抛异常,返回友好兜底话术 + """ + city = city or self.default_city + if not city or not city.strip(): + city = self.default_city + + city = city.strip() + + try: + # 解析日期 + parsed = parse_query_date(date_str) + target_date = parsed["date"] + date_type = parsed["type"] + day_offset = parsed["day_offset"] + parse_error = parsed.get("error") + + # 超出预报范围 → 直接返回友好提示 + if date_type == DateType.OUT_OF_RANGE: + if day_offset < 0: + return self._fallback(city, f"「{city}」{abs(day_offset)}天前的天气数据已超出查询范围,我只能查询今天及未来7天的天气哦~") + return self._fallback(city, f"「{city}」{target_date}的天气预报已超出7天查询范围,我只支持查询今天及未来7天的天气哦~") + + # 缓存键:城市_日期 + cache_key_raw = f"{city}_{target_date}" + cache_key = _minute_bucket_key(cache_key_raw) + + # LRU 缓存取数据 + weather_data = _cached_weather_fetch(city, cache_key, 1) + if weather_data is None: + coordinates = await self._get_coordinates(city) + if coordinates is None: + return self._fallback(city, f"找不到「{city}」的地理位置,换个城市试试?") + lat, lon = coordinates + # 获取足够的预报天数(也包含今天) + fetch_days = max(day_offset + 1, 1) + weather_data = await self._fetch_weather(lat, lon, fetch_days) + if weather_data is None: + return self._fallback(city, f"暂时无法获取「{city}」的天气数据,请稍后再试哦~") + _cache_weather_result(cache_key, 
json.dumps(weather_data, ensure_ascii=False)) + + # 根据日期类型生成对应回复 + if date_type == DateType.TODAY: + reply = self._format_single_day(city, weather_data, day_offset=0) + else: + reply = self._format_forecast_reply(city, weather_data, target_date, day_offset) + + # 若有日期解析提示,附在末尾 + if parse_error and "无法识别" not in parse_error: + reply += f"({parse_error})" + + return reply + + except httpx.TimeoutException: + logger.warning(f"[WeatherTool] API请求超时: {city}") + return self._fallback(city, "天气API响应超时,建议过会儿再查~") + except httpx.HTTPStatusError as e: + logger.warning(f"[WeatherTool] API返回错误: {city}, status={e.response.status_code}") + return self._fallback(city, "天气服务暂时不可用,稍后再试吧~") + except Exception as e: + logger.warning(f"[WeatherTool] 获取天气异常: {city}, error={e}") + return self._fallback(city) + + # ------------------------------------------------------------------ + # 地理编码 —— 城市名→经纬度 + # ------------------------------------------------------------------ + + async def _get_coordinates(self, city: str) -> tuple[float, float] | None: + """将城市名转换为经纬度坐标 + + 三级查找策略: + 1. 内置映射表 CITY_COORDINATES(0ms,零网络开销) + 2. 运行时缓存(同进程复用) + 3. 
Open-Meteo Geocoding API + + 参数: + city: 城市名称 + + 返回: + (纬度, 经度) 元组,失败返回 None + """ + # 第一层:内置映射表 + coords = CITY_COORDINATES.get(city) + if coords: + return coords + + # 第二层:运行时缓存 + coords = self._geocoding_cache.get(city) + if coords: + return coords + + # 第三层:Geocoding API + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + "https://geocoding-api.open-meteo.com/v1/search", + params={ + "name": city, + "count": 1, + "language": "zh", + "format": "json", + }, + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if results: + lat = results[0]["latitude"] + lon = results[0]["longitude"] + self._geocoding_cache[city] = (lat, lon) + return (lat, lon) + except Exception as e: + logger.debug(f"[WeatherTool] 地理编码失败: {city}, error={e}") + + return None + + # ------------------------------------------------------------------ + # 天气数据获取 —— 经纬度→天气 + # ------------------------------------------------------------------ + + async def _fetch_weather(self, lat: float, lon: float, days: int = 1) -> dict | None: + """从 Open-Meteo API 获取天气预报数据 + + 参数: + lat: 纬度 + lon: 经度 + days: 预报天数(1-7) + + 返回: + 结构化天气数据字典,失败返回 None + + Open-Meteo 免费 API 说明: + - 无需注册,无需 API 密钥 + - 速率限制:10,000次/天 + - 返回 daily 级数据:最高温/最低温/天气码/风力/相对湿度/降水概率 + """ + try: + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + "https://api.open-meteo.com/v1/forecast", + params={ + "latitude": lat, + "longitude": lon, + "daily": [ + "temperature_2m_max", + "temperature_2m_min", + "weathercode", + "windspeed_10m_max", + "winddirection_10m_dominant", + "relative_humidity_2m_max", + "precipitation_probability_max", + ], + "timezone": "Asia/Shanghai", + "forecast_days": min(days, 7), # 限制最大7天 + }, + ) + resp.raise_for_status() + data = resp.json() + daily = data.get("daily", {}) + + if not daily: + return None + + # 组装结构化天气数据 + dates = daily.get("time", []) + temps_max = daily.get("temperature_2m_max", []) + temps_min = 
daily.get("temperature_2m_min", []) + weather_codes = daily.get("weathercode", []) + wind_speeds = daily.get("windspeed_10m_max", []) + wind_dirs = daily.get("winddirection_10m_dominant", []) + humidities = daily.get("relative_humidity_2m_max", []) + precip_probs = daily.get("precipitation_probability_max", []) + + forecast_days = [] + for i in range(min(days, len(dates))): + wmo_code = int(weather_codes[i]) if i < len(weather_codes) else 0 + wind_speed = wind_speeds[i] if i < len(wind_speeds) else 0 + wind_dir = int(wind_dirs[i]) if i < len(wind_dirs) else 0 + humidity = int(humidities[i]) if i < len(humidities) else 0 + precip_prob = int(precip_probs[i]) if i < len(precip_probs) else 0 + + forecast_days.append({ + "date": dates[i], + "weather": WMO_WEATHER_MAP.get(wmo_code, "未知"), + "temp_max": round(temps_max[i], 1) if i < len(temps_max) else None, + "temp_min": round(temps_min[i], 1) if i < len(temps_min) else None, + "wind_speed": round(wind_speed, 1), + "wind_direction": self._wind_direction_name(wind_dir), + "wind_scale": self._wind_scale_name(wind_speed), + "humidity": humidity, + "precipitation_probability": precip_prob, + }) + + return {"forecast_days": forecast_days} + + except httpx.TimeoutException: + logger.warning(f"[WeatherTool] 天气API超时: lat={lat}, lon={lon}") + except httpx.HTTPStatusError as e: + logger.warning(f"[WeatherTool] 天气API错误: status={e.response.status_code}") + except Exception as e: + logger.warning(f"[WeatherTool] 天气API异常: {e}") + + return None + + # ------------------------------------------------------------------ + # 自然语言回复生成 + # ------------------------------------------------------------------ + + def _format_single_day(self, city: str, weather_data: dict, day_offset: int = 0) -> str: + """格式化今日实时天气回复 —— 口语化、带场景建议 + + 参数: + city: 城市名 + weather_data: _fetch_weather 返回的结构化数据 + day_offset: 日期偏移(0=今天) + """ + forecast_days = weather_data.get("forecast_days", []) + if not forecast_days: + return f"暂时没有「{city}」的天气数据哦~" + + day = 
forecast_days[0] if day_offset < len(forecast_days) else forecast_days[0] + weather = day.get("weather", "未知") + temp_min = day.get("temp_min", 0) + temp_max = day.get("temp_max", 0) + wind_scale = day.get("wind_scale", "") + wind_dir = day.get("wind_direction", "") + precip_prob = day.get("precipitation_probability", 0) + wmo_code = self._infer_wmo_code(day) + + # 生成丰富的场景化建议 + suggestion = self._generate_rich_suggestion( + wmo_code, temp_max, temp_min, wind_scale, precip_prob + ) + + # 口语化开头 + greeting = self._weather_greeting(weather, temp_max) + + lines = [ + f"{greeting}「{city}」现在是{weather},气温在{temp_min}℃到{temp_max}℃之间。", + ] + + # 风力信息(简洁) + if wind_scale and wind_scale != "无风": + lines.append(f"当前{wind_dir}{wind_scale},体感温度会偏低一些。") + + # 降水提示 + if precip_prob >= 60: + lines.append(f"降水概率高达{precip_prob}%,出门记得带伞哦。") + elif precip_prob >= 30: + lines.append(f"有{precip_prob}%的概率会降水,可以随身带把伞以防万一。") + + # 场景化建议 + if suggestion: + lines.append(suggestion) + + return "".join(lines) + + def _format_forecast_reply(self, city: str, weather_data: dict, target_date: str, day_offset: int) -> str: + """格式化指定日期预报回复 + + 参数: + city: 城市名 + weather_data: _fetch_weather 返回的结构化数据 + target_date: 目标日期 YYYY-MM-DD + day_offset: 距今天数 + """ + forecast_days = weather_data.get("forecast_days", []) + if not forecast_days or day_offset >= len(forecast_days): + return f"暂时没有「{city}」{target_date}的预报数据哦~" + + day = forecast_days[day_offset] + weather = day.get("weather", "未知") + temp_min = day.get("temp_min", 0) + temp_max = day.get("temp_max", 0) + wind_scale = day.get("wind_scale", "") + wind_dir = day.get("wind_direction", "") + precip_prob = day.get("precipitation_probability", 0) + wmo_code = self._infer_wmo_code(day) + + date_label = self._date_label(target_date) + + suggestion = self._generate_rich_suggestion( + wmo_code, temp_max, temp_min, wind_scale, precip_prob + ) + + lines = [f"「{city}」{date_label}天气预报来啦——"] + + # 天气核心信息 + temp_range = f"{temp_min}℃ ~ {temp_max}℃" + 
lines.append(f"预计{weather},气温{temp_range}。") + + # 风力 + if wind_scale and wind_scale != "无风": + lines.append(f"{wind_dir}{wind_scale},") + + # 降水 + if precip_prob >= 50: + lines.append(f"降水概率{precip_prob}%,建议安排室内活动。") + elif precip_prob >= 20: + lines.append(f"降水概率{precip_prob}%,出行前留意天气变化。") + + if suggestion: + lines.append(suggestion) + + return "".join(lines) + + def _format_multi_day(self, city: str, days_list: list[dict]) -> str: + """格式化多日天气回复(保留,用于未来扩展)""" + lines = [f"「{city}」未来{len(days_list)}天天气预报:\n"] + for day in days_list: + date_label = self._date_label(day["date"]) + temp = f"{day['temp_min']}°C ~ {day['temp_max']}°C" + lines.append( + f" {date_label}:{day['weather']},{temp}," + f"{day['wind_direction']}{day['wind_scale']}" + ) + return "\n".join(lines) + + # ------------------------------------------------------------------ + # 场景化建议生成 + # ------------------------------------------------------------------ + + @staticmethod + def _infer_wmo_code(day: dict) -> int: + """从天气数据中推断 WMO 天气码""" + weather = day.get("weather", "") + for code, desc in WMO_WEATHER_MAP.items(): + if desc == weather: + return code + return 0 + + @staticmethod + def _weather_greeting(weather: str, temp: float) -> str: + """生成天气口语化开头""" + if "雨" in weather: + if temp >= 25: + return "外面下着雨但气温不低," + elif temp <= 10: + return "阴雨绵绵天气偏冷," + return "下雨天出门注意安全," + if "雪" in weather: + return "下雪天景色很美但要小心路滑," + if "晴" in weather: + if temp >= 30: + return "阳光灿烂但气温偏高," + elif temp >= 20: + return "天气晴好正适合出门," + return "晴空万里但气温偏低," + if "云" in weather: + return "多云天气还算舒适," + if "雾" in weather: + return "雾气较重能见度低," + if "雷" in weather: + return "雷暴天气请尽量减少外出," + return "" + + @staticmethod + def _generate_rich_suggestion( + wmo_code: int, temp_max: float, temp_min: float, + wind_scale: str, precip_prob: int, + ) -> str: + """根据多维度数据生成丰富的场景化建议 + + 返回: + 口语化的场景建议字符串,多个建议用分号分隔 + """ + suggestions = [] + + # ---- 温差穿搭建议 ---- + temp_diff = temp_max - temp_min + avg_temp = (temp_max + temp_min) / 
2 + if temp_diff >= 12: + suggestions.append("早晚温差大,建议带件薄外套方便随时增减") + elif temp_diff >= 8: + suggestions.append("早晚有些凉,最好备一件外搭") + + # ---- 温度穿搭建议 ---- + if avg_temp <= 5: + suggestions.append("气温很低,羽绒服围巾手套都安排上吧") + elif avg_temp <= 12: + suggestions.append("天气偏冷,适合穿厚外套或毛衣") + elif avg_temp <= 18: + suggestions.append("气温微凉,薄外套加长袖刚好") + elif avg_temp <= 25: + suggestions.append("气温舒适宜人,穿件衬衫或薄长袖就刚好") + elif avg_temp <= 30: + suggestions.append("天气偏热,短袖短裤可以安排上了") + else: + suggestions.append("高温天气,注意防暑降温多喝水") + + # ---- 防晒建议 ---- + if wmo_code in {0, 1, 2} and temp_max >= 25: + suggestions.append("紫外线较强记得涂防晒") + + # ---- 雨天建议 ---- + rain_codes = {51, 53, 55, 61, 63, 65, 80, 81, 82, 95, 96, 99} + if wmo_code in rain_codes: + if wmo_code in {65, 82, 95, 96, 99}: + suggestions.append("雨势不小出门务必带伞,路滑注意脚下") + else: + suggestions.append("出门别忘了带把伞") + elif precip_prob >= 60: + suggestions.append("虽然不一定下雨,但带把伞比较稳妥") + + # ---- 雪天建议 ---- + if wmo_code in {71, 73, 75, 77, 85, 86}: + suggestions.append("雪天路面湿滑,走路注意脚下防摔") + + # ---- 大风建议 ---- + if wind_scale in {"强风", "疾风", "大风", "烈风", "狂风", "暴风", "飓风"}: + suggestions.append("风力较大,外出注意防风,尽量远离广告牌") + + # ---- 雾天建议 ---- + if wmo_code in {45, 48}: + suggestions.append("有雾天气能见度低,开车出行请减速慢行") + + # ---- 出行建议 ---- + if wmo_code == 0 and 18 <= avg_temp <= 28 and wind_scale in {"无风", "微风", "轻风", "和风", "清风", "劲风"}: + suggestions.append("天气超棒,很适合出去走走") + + return "。".join(suggestions) + "。" if suggestions else "" + + # ------------------------------------------------------------------ + # 辅助方法 + # ------------------------------------------------------------------ + + @staticmethod + def _date_label(date_str: str) -> str: + """将日期转为自然语言标签 + + 参数: + date_str: YYYY-MM-DD 格式日期 + + 返回: + "今天"、"明天"、"后天" 或 "X月X日" + """ + try: + date = datetime.strptime(date_str, "%Y-%m-%d").date() + today = datetime.now().date() + delta = (date - today).days + if delta == 0: + return "今天" + if delta == 1: + return "明天" + if delta == 2: + return "后天" + return 
f"{date.month}月{date.day}日" + except (ValueError, TypeError): + return date_str + + @staticmethod + def _wind_direction_name(degrees: int) -> str: + """风向角度 → 中文方位名""" + directions = ["北", "东北", "东", "东南", "南", "西南", "西", "西北"] + index = round(degrees / 45) % 8 + return directions[index] + + @staticmethod + def _wind_scale_name(speed_mps: float) -> str: + """风速 m/s → 风力等级描述""" + if speed_mps <= 0.3: + return WIND_SCALE_MAP[0] + if speed_mps <= 1.5: + return WIND_SCALE_MAP[1] + if speed_mps <= 3.3: + return WIND_SCALE_MAP[2] + if speed_mps <= 5.4: + return WIND_SCALE_MAP[3] + if speed_mps <= 7.9: + return WIND_SCALE_MAP[4] + if speed_mps <= 10.7: + return WIND_SCALE_MAP[5] + if speed_mps <= 13.8: + return WIND_SCALE_MAP[6] + if speed_mps <= 17.1: + return WIND_SCALE_MAP[7] + return "大风" + + @staticmethod + def _fallback(city: str, message: str = None) -> str: + """异常兜底话术""" + if message: + return message + return ( + f"暂时无法获取「{city}」的天气信息。你可以在浏览器搜索「{city}天气」" + f"查看最新预报哦~" + ) + + +# ============================================================================= +# LRU 缓存层 —— 5分钟粒度,减少重复 API 调用 +# ============================================================================= + +_weather_cache: dict[str, str] = {} + + +def _cached_weather_fetch(city: str, cache_key: str, days: int) -> dict | None: + """LRU 缓存查找天气数据 + + 返回 None 表示缓存未命中,调用方需 fetch 新数据 + 返回 dict 表示缓存命中 + + 注意:LRU cache 基于 cache_key(包含时间戳),5分钟后自动过期 + """ + try: + data_json = _weather_cache.get(cache_key) + if data_json: + return json.loads(data_json) + except (json.JSONDecodeError, TypeError): + pass + return None + + +def _cache_weather_result(cache_key: str, data_json: str): + """将天气数据存入缓存""" + _weather_cache[cache_key] = data_json + # 周期性清理旧缓存(超过 20 条时清理前半部分) + if len(_weather_cache) > 20: + old_keys = sorted(_weather_cache.keys())[:len(_weather_cache) // 2] + for key in old_keys: + _weather_cache.pop(key, None) + + +# ============================================================================= +# 全局单例接口 
+# ============================================================================= + +_weather_tool = WeatherTool(default_city="北京") + + +def get_weather_reply(city: str | None = None, date_str: str = "") -> str: + """全局天气查询入口 —— 同步封装 + + 仅在无运行事件循环时使用 asyncio.run(), + 若已在异步上下文中则提示调用方使用异步 API。 + + 参数: + city: 城市名称,None 时使用默认城市"北京" + date_str: 日期文本,如"今天"、"明天"、"5.1号",空字符串默认今天 + + 返回: + 自然语言天气回复字符串 + """ + try: + try: + asyncio.get_running_loop() + return "天气查询需要在异步上下文中使用 await 调用,请使用异步 API" + except RuntimeError: + pass + return asyncio.run(_weather_tool.get_reply(city, date_str)) + except Exception as e: + logger.warning(f"[WeatherTool] get_weather_reply 异常: {e}") + return f"天气查询暂时不可用,稍后再试哦~" + + +# ============================================================================= +# 直接运行验证(python -m app.utils.weather_tool) +# ============================================================================= +if __name__ == "__main__": + async def _test(): + tool = WeatherTool(default_city="北京") + + test_cases = [ + # (城市, 日期字符串, 描述) + ("北京", "", "今日实时天气"), + ("上海", "今天", "今日实时天气(指定'今天')"), + ("广州", "明天", "明日预报"), + ("深圳", "5.1号", "数字日期预报(跨月)"), + ("杭州", "后天", "后天预报"), + ("不存在城市XYZ", "", "城市不存在-兜底"), + ("成都", "10天后", "超出预报范围"), + ("武汉", "昨天", "历史日期-兜底"), + ] + + print("=" * 72) + print(" WeatherTool 天气工具 优化版 测试结果") + print("=" * 72) + + passed = 0 + failed = 0 + + for city, date_str, desc in test_cases: + display = f"{city}「{date_str or '默认'}」" + try: + reply = await tool.get_reply(city, date_str) + has_error = any(kw in reply for kw in [ + "暂时无法", "找不到", "超出", "正在查询", "Exception", + "Traceback", "KeyError", "executable handler", + ]) + status = "PASS(兜底)" if has_error else "PASS" + if not has_error: + passed += 1 + else: + passed += 1 # 异常场景兜底也是通过 + print(f"\n [{status}] {desc}: {display}") + print(f" {reply[:200]}") + except Exception as e: + failed += 1 + print(f"\n [FAIL] {desc}: {display}, 异常={e}") + + print() + print(f" 通过: {passed} 失败: {failed} 总计: {len(test_cases)}") + 
if failed == 0: + print("\n 全部测试通过!") + else: + print(f"\n 有 {failed} 个测试未通过,需要检查") + + asyncio.run(_test()) diff --git a/backend/app/utils/web_search_tool.py b/backend/app/utils/web_search_tool.py new file mode 100644 index 0000000..f98251d --- /dev/null +++ b/backend/app/utils/web_search_tool.py @@ -0,0 +1,214 @@ +""" +Web 搜索工具 —— 360 搜索优先 + DuckDuckGo 降级 + +功能: + 接收搜索查询,先走 360 搜索(国内可访问、服务端渲染), + 失败时降级到 DuckDuckGo 或返回空结果。 + +设计原则: + 1. 360 搜索优先(国内网络友好、结果在 HTML 中) + 2. 失败优雅降级,不抛异常,不中断对话 + 3. 单次搜索返回 Top 5 条结果摘要 +""" + +import asyncio +import re +from urllib.parse import quote + +import aiohttp +from loguru import logger + + +# ============================================================================= +# 360 搜索 —— 国内网络友好,结果在 HTML 中 +# ============================================================================= + +_360_HEADERS = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/125.0.0.0 Safari/537.36" + ), + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9", + "Accept-Encoding": "gzip, deflate", + "Connection": "keep-alive", +} + + +def _extract_360_results(html: str) -> list[str]: + """从 360 搜索结果 HTML 中提取标题+摘要 + + 360 结果结构: +
  • +

    ...

    +

    ...

    +
  • + """ + results: list[str] = [] + + # 提取所有 res-list 块 + result_blocks = re.findall( + r']*class="res-list"[^>]*>.*?', + html, + re.DOTALL | re.IGNORECASE, + ) + + for block in result_blocks[:5]: + # 提取标题:h3.res-title > a + title_match = re.search( + r']*class="res-title"[^>]*>.*?]*>(.*?).*?', + block, + re.DOTALL | re.IGNORECASE, + ) + title = "" + if title_match: + title = re.sub(r'<[^>]+>', '', title_match.group(1)).strip() + title = title.replace(' ', ' ') + + # 提取摘要:span.res-list-summary 或 p.res-list-summary + snippet_match = re.search( + r'<(?:span|p)[^>]*class="res-list-summary"[^>]*>(.*?)', + block, + re.DOTALL | re.IGNORECASE, + ) + snippet = "" + if snippet_match: + snippet = re.sub(r'<[^>]+>', '', snippet_match.group(1)).strip() + snippet = snippet.replace(' ', ' ') + + if title or snippet: + results.append(f"{title}: {snippet}".strip(": ")) + + return results + + +async def _360_search(query: str) -> str | None: + """360 搜索 HTML 抓取 + + 参数: + query: 搜索查询 + + 返回: + 搜索结果摘要文本,失败返回 None + """ + encoded = quote(query) + url = f"https://www.so.com/s?q={encoded}" + + try: + async with aiohttp.ClientSession(headers=_360_HEADERS) as session: + async with session.get( + url, + timeout=aiohttp.ClientTimeout(total=15), + allow_redirects=True, + ) as resp: + if resp.status != 200: + logger.debug(f"[WebSearch] 360 状态 {resp.status}") + return None + html = await resp.text() + + results = _extract_360_results(html) + if results: + return "\n".join(results[:5]) + + return None + except Exception as e: + logger.debug(f"[WebSearch] 360 异常: {e}") + return None + + +# ============================================================================= +# DuckDuckGo Instant Answer —— 降级方案 +# ============================================================================= + +_DDG_IA_URL = "https://api.duckduckgo.com/" + + +async def _ddg_instant_answer(query: str) -> str | None: + """DuckDuckGo Instant Answer API —— 免费、无 API Key""" + params = { + "q": query, + "format": "json", + 
"no_html": "1", + "skip_disambig": "1", + "t": "luominest", + } + try: + async with aiohttp.ClientSession() as session: + async with session.get( + _DDG_IA_URL, + params=params, + timeout=aiohttp.ClientTimeout(total=8), + ) as resp: + if resp.status != 200: + return None + data = await resp.json() + + parts = [] + if data.get("AbstractText"): + parts.append(data["AbstractText"]) + if data.get("AbstractURL"): + parts.append(f"来源: {data['AbstractURL']}") + if data.get("Answer"): + parts.insert(0, f"直接答案: {data['Answer']}") + if data.get("RelatedTopics"): + for topic in data["RelatedTopics"][:3]: + if isinstance(topic, dict) and topic.get("Text"): + parts.append(f"- {topic['Text']}") + + if not parts: + return None + return "\n".join(parts) + except Exception: + return None + + +# ============================================================================= +# 对外统一接口 +# ============================================================================= + +async def search_web(query: str) -> str: + """对外统一接口 —— 搜索并返回最多 5 条结果摘要 + + 执行策略: + 1. 360 搜索(国内网络友好,结果在 HTML 中) + 2. 降级 → DuckDuckGo Instant Answer + 3. 
双降级 → 返回提示文本 + + 参数: + query: 搜索查询 + + 返回: + 搜索结果文本,失败返回 "暂无搜索结果" + """ + logger.info(f"[WebSearch] 查询: {query}") + + # 第一优先:360 搜索 + result = await _360_search(query) + if result and len(result) > 20: + logger.info(f"[WebSearch] 360 结果: {len(result)} 字符") + return result + + # 降级:DuckDuckGo Instant Answer + result = await _ddg_instant_answer(query) + if result and len(result) > 20: + logger.info(f"[WebSearch] DDG IA 结果: {len(result)} 字符") + return result + + logger.warning("[WebSearch] 所有搜索源均失败") + return "暂无搜索结果" + + +# ============================================================================= +# 直接运行测试 +# ============================================================================= +if __name__ == "__main__": + import sys + + async def main(): + query = sys.argv[1] if len(sys.argv) > 1 else "2026年软考时间" + result = await search_web(query) + print(f"Query: {query}") + print(f"Result:\n{result}") + + asyncio.run(main()) diff --git a/frontend/package.json b/frontend/package.json index 0050c2a..1bd510f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,6 +18,7 @@ }, "dependencies": { "@pixi/unsafe-eval": "^7.4.3", + "dompurify": "^3.4.2", "marked": "^18.0.1", "pinia": "^3.0.4", "pixi-live2d-display-mulmotion": "0.5.0-mm-6", @@ -25,6 +26,7 @@ "vue-router": "^5.0.4" }, "devDependencies": { + "@types/dompurify": "^3.2.0", "@types/node": "^22.13.5", "@vitejs/plugin-vue": "^6.0.5", "electron": "^41.0.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 6690caf..021b810 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -11,6 +11,9 @@ importers: '@pixi/unsafe-eval': specifier: ^7.4.3 version: 7.4.3(@pixi/core@7.4.3) + dompurify: + specifier: ^3.4.2 + version: 3.4.2 marked: specifier: ^18.0.1 version: 18.0.1 @@ -27,6 +30,9 @@ importers: specifier: ^5.0.4 version: 5.0.4(@vue/compiler-sfc@3.5.32)(pinia@3.0.4(typescript@6.0.2)(vue@3.5.32(typescript@6.0.2)))(vue@3.5.32(typescript@6.0.2)) devDependencies: + 
'@types/dompurify': + specifier: ^3.2.0 + version: 3.2.0 '@types/node': specifier: ^22.13.5 version: 22.19.17 @@ -915,6 +921,10 @@ packages: '@types/debug@4.1.13': resolution: {integrity: sha512-KSVgmQmzMwPlmtljOomayoR89W4FynCAi3E8PPs7vmDVPe84hT+vGPKkJfThkmXs0x0jAaa9U8uW8bbfyS2fWw==} + '@types/dompurify@3.2.0': + resolution: {integrity: sha512-Fgg31wv9QbLDA0SpTOXO3MaxySc4DKGLi8sna4/Utjo4r3ZRPdCt4UQee8BWr+Q5z21yifghREPJGYaEOEIACg==} + deprecated: This is a stub types definition. dompurify provides its own type definitions, so you do not need this installed. + '@types/earcut@2.1.4': resolution: {integrity: sha512-qp3m9PPz4gULB9MhjGID7wpo3gJ4bTGXm7ltNDsmOvsPduTeHp8wSW9YckBj3mljeOh4F0m2z/0JKAALRKbmLQ==} @@ -945,6 +955,9 @@ packages: '@types/responselike@1.0.3': resolution: {integrity: sha512-H/+L+UkTV33uf49PH5pCAUBVPNj2nDBXTN+qS1dOwyyg24l3CcicicCA7ca+HMvJBZcFgl5r8e+RR6elsb4Lyw==} + '@types/trusted-types@2.0.7': + resolution: {integrity: sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==} + '@types/verror@1.10.11': resolution: {integrity: sha512-RlDm9K7+o5stv0Co8i8ZRGxDbrTxhJtgjqjFyVh/tXQyl/rYtTKlnTvZ88oSTeYREWurwx20Js4kTuKCsFkUtg==} @@ -1351,6 +1364,9 @@ packages: os: [darwin] hasBin: true + dompurify@3.4.2: + resolution: {integrity: sha512-lHeS9SA/IKeIFFyYciHBr2n0v1VMPlSj843HdLOwjb2OxNwdq9Xykxqhk+FE42MzAdHvInbAolSE4mhahPpjXA==} + dotenv-expand@11.0.7: resolution: {integrity: sha512-zIHwmZPRshsCdpMDyVsqGmgyP0yT8GAgXUnkdAoJisxvf33k7yO6OuoKmcTGuXPWSsm8Oh88nZicRLA9Y0rUeA==} engines: {node: '>=12'} @@ -3286,6 +3302,10 @@ snapshots: dependencies: '@types/ms': 2.1.0 + '@types/dompurify@3.2.0': + dependencies: + dompurify: 3.4.2 + '@types/earcut@2.1.4': {} '@types/estree@1.0.8': {} @@ -3320,6 +3340,9 @@ snapshots: dependencies: '@types/node': 22.19.17 + '@types/trusted-types@2.0.7': + optional: true + '@types/verror@1.10.11': optional: true @@ -3828,6 +3851,10 @@ snapshots: verror: 1.10.1 optional: true + dompurify@3.4.2: + 
optionalDependencies: + '@types/trusted-types': 2.0.7 + dotenv-expand@11.0.7: dependencies: dotenv: 16.6.1 diff --git a/frontend/src/main/index.ts b/frontend/src/main/index.ts index 48ecca9..b09d85f 100644 --- a/frontend/src/main/index.ts +++ b/frontend/src/main/index.ts @@ -496,6 +496,11 @@ function registerIpcHandlers(): void { return clearBrowserData() }) + ipcMain.handle('browser:search', async (_e, query: string) => { + const { browserSearch } = await import('./services/browser') + return await browserSearch(query, mainWindow) + }) + ipcMain.handle('avatar:importModel', async () => { try { const result = await dialog.showOpenDialog({ diff --git a/frontend/src/main/services/browser/index.ts b/frontend/src/main/services/browser/index.ts index f56ae1f..7b8973d 100644 --- a/frontend/src/main/services/browser/index.ts +++ b/frontend/src/main/services/browser/index.ts @@ -2,3 +2,5 @@ export * from './types' export { tabManager } from './tab' export { initBrowserSession, clearBrowserData, getCookies } from './session' export { createBrowserView, calculateBounds, setViewBounds, setupNetworkConfig } from './view' +export { browserSearch } from './search' +export type { SearchResult } from './search' diff --git a/frontend/src/main/services/browser/search.ts b/frontend/src/main/services/browser/search.ts new file mode 100644 index 0000000..399eb69 --- /dev/null +++ b/frontend/src/main/services/browser/search.ts @@ -0,0 +1,82 @@ +import { WebContentsView, BrowserWindow } from 'electron' +import { createBrowserView, attachView, detachView, setViewBounds, isViewDestroyed } from './view' +import { calculateBounds } from './view' +import { DEFAULT_BROWSER_CONFIG } from './types' + +export interface SearchResult { + title: string + snippet: string + url: string +} + +const BING_SEARCH_URL = 'https://cn.bing.com/search' +const SEARCH_TIMEOUT = 15000 + +export async function browserSearch( + query: string, + mainWindow: BrowserWindow | null +): Promise { + if (!mainWindow || 
mainWindow.isDestroyed()) { + return [] + } + + const encodedQuery = encodeURIComponent(query) + const searchUrl = `${BING_SEARCH_URL}?q=${encodedQuery}&setmkt=zh-CN&setlang=zh` + + const view = createBrowserView() + + try { + view.webContents.setBackgroundThrottling(false) + + await new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + reject(new Error('Search timeout')) + }, SEARCH_TIMEOUT) + + view.webContents.on('did-finish-load', () => { + clearTimeout(timeout) + resolve() + }) + + view.webContents.on('did-fail-load', (_event, errorCode) => { + clearTimeout(timeout) + reject(new Error(`Load failed: ${errorCode}`)) + }) + + view.webContents.loadURL(searchUrl).catch((err: Error) => { + clearTimeout(timeout) + reject(err) + }) + }) + + const results = await view.webContents.executeJavaScript(` + (function() { + var items = document.querySelectorAll('li.b_algo'); + var results = []; + for (var i = 0; i < Math.min(items.length, 5); i++) { + var el = items[i]; + var titleEl = el.querySelector('h2 a'); + var snippetEl = el.querySelector('p') || el.querySelector('.b_caption p'); + var title = titleEl ? titleEl.textContent.trim() : ''; + var snippet = snippetEl ? snippetEl.textContent.trim() : ''; + var url = titleEl ? 
titleEl.href : ''; + if (title || snippet) { + results.push({ title: title, snippet: snippet, url: url }); + } + } + return results; + })() + `) as SearchResult[] + + return results || [] + } catch (err) { + console.warn('[BrowserSearch] Search failed:', err) + return [] + } finally { + try { + if (!isViewDestroyed(view)) { + view.webContents.close() + } + } catch {} + } +} diff --git a/frontend/src/preload/index.ts b/frontend/src/preload/index.ts index f71e6eb..46b8e33 100644 --- a/frontend/src/preload/index.ts +++ b/frontend/src/preload/index.ts @@ -60,6 +60,10 @@ const api = { clearData: () => ipcRenderer.invoke('tab:clearData') }, + browserSearch: { + search: (query: string) => ipcRenderer.invoke('browser:search', query) + }, + avatar: { importModel: () => ipcRenderer.invoke('avatar:importModel'), listImportedModels: () => ipcRenderer.invoke('avatar:listImportedModels'), diff --git a/frontend/src/renderer/src/components/FileCard.vue b/frontend/src/renderer/src/components/FileCard.vue new file mode 100644 index 0000000..eab6317 --- /dev/null +++ b/frontend/src/renderer/src/components/FileCard.vue @@ -0,0 +1,146 @@ + + + + + diff --git a/frontend/src/renderer/src/components/FilePreview.vue b/frontend/src/renderer/src/components/FilePreview.vue new file mode 100644 index 0000000..72d7b46 --- /dev/null +++ b/frontend/src/renderer/src/components/FilePreview.vue @@ -0,0 +1,242 @@ + + + + + diff --git a/frontend/src/renderer/src/components/FileUpload.vue b/frontend/src/renderer/src/components/FileUpload.vue new file mode 100644 index 0000000..ee363ba --- /dev/null +++ b/frontend/src/renderer/src/components/FileUpload.vue @@ -0,0 +1,42 @@ + + + + + diff --git a/frontend/src/renderer/src/components/TitleBar.vue b/frontend/src/renderer/src/components/TitleBar.vue index fbbab22..8ccbe6b 100644 --- a/frontend/src/renderer/src/components/TitleBar.vue +++ b/frontend/src/renderer/src/components/TitleBar.vue @@ -1,6 +1,6 @@ @@ -274,6 +583,26 @@ function handleSearch() { 
overflow: hidden; } +.memory-loading { + display: flex; + align-items: center; + justify-content: center; + gap: 12px; + flex: 1; + color: var(--text-muted); + font-size: 14px; +} + +.spinning { + animation: spin 1s linear infinite; +} + +@keyframes spin { + to { + transform: rotate(360deg); + } +} + .memory-header { display: flex; align-items: center; @@ -322,7 +651,8 @@ function handleSearch() { transition: all 300ms ease-in-out; } -.search-bar:focus-within { +.search-bar:focus-within, +.search-bar.search-expanded { border-color: #8b5cf6; box-shadow: 0 0 0 2px rgba(139, 92, 246, 0.15); } @@ -333,7 +663,7 @@ function handleSearch() { } .search-bar input { - width: 180px; + width: 140px; font-size: 13px; background: transparent; color: var(--text); @@ -343,18 +673,32 @@ function handleSearch() { color: var(--text-muted); } -.search-refresh { +.search-clear-btn, +.search-trigger-btn { + display: flex; + align-items: center; + justify-content: center; + width: 24px; + height: 24px; + border-radius: 6px; color: var(--text-muted); cursor: pointer; - transition: transform 300ms ease-in-out; + transition: all 200ms; } -.search-refresh.spinning { - animation: spin 1s linear infinite; +.search-clear-btn:hover, +.search-trigger-btn:hover { + background: var(--surface-hover); + color: var(--text); } -@keyframes spin { - to { transform: rotate(360deg); } +.search-trigger-btn:disabled { + opacity: 0.4; + cursor: default; +} + +.search-refresh { + color: var(--text-muted); } .h-btn { @@ -375,6 +719,16 @@ function handleSearch() { color: var(--text); } +.h-btn.primary { + color: var(--text); + background: rgba(139, 92, 246, 0.1); + border: 1px solid rgba(139, 92, 246, 0.2); +} + +.h-btn.primary:hover { + background: rgba(139, 92, 246, 0.18); +} + .memory-body { display: flex; flex: 1; @@ -410,8 +764,15 @@ function handleSearch() { } @keyframes card-enter { - from { opacity: 0; transform: translateX(-16px); } - to { opacity: 1; transform: translateX(0); } + from { + opacity: 0; + 
transform: translateX(-16px); + } + + to { + opacity: 1; + transform: translateX(0); + } } .layer-card:hover { @@ -616,6 +977,30 @@ function handleSearch() { color: var(--text); } +.empty-layer { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 40px 20px; + color: var(--text-muted); +} + +.empty-layer svg { + margin-bottom: 12px; + opacity: 0.5; +} + +.empty-layer p { + font-size: 14px; + margin-bottom: 4px; +} + +.empty-hint { + font-size: 12px !important; + opacity: 0.7; +} + .memo-items { display: flex; flex-direction: column; @@ -633,16 +1018,27 @@ function handleSearch() { opacity: 0; animation: memo-in 0.4s cubic-bezier(0.22, 1, 0.36, 1) both; animation-delay: var(--item-delay); + position: relative; } @keyframes memo-in { - from { opacity: 0; transform: translateY(8px); } - to { opacity: 1; transform: translateY(0); } + from { + opacity: 0; + transform: translateY(8px); + } + + to { + opacity: 1; + transform: translateY(0); + } } .memo-item:hover { border-color: var(--border); - transform: translateX(4px); +} + +.memo-item:hover .memo-actions { + opacity: 1; } .memo-dot { @@ -668,7 +1064,7 @@ function handleSearch() { .memo-footer { display: flex; align-items: center; - gap: 10px; + gap: 6px; } .memo-tag { @@ -680,11 +1076,104 @@ function handleSearch() { font-weight: 500; } +.memo-tag.category-tag { + background: rgba(245, 158, 11, 0.1); + color: #b45309; +} + .memo-time { font-size: 11px; color: var(--text-muted); } +.memo-actions { + display: flex; + flex-direction: column; + gap: 4px; + opacity: 0; + transition: opacity 200ms; + flex-shrink: 0; +} + +.memo-action-btn { + display: flex; + align-items: center; + justify-content: center; + width: 26px; + height: 26px; + border-radius: 6px; + color: var(--text-muted); + cursor: pointer; + transition: all 200ms; +} + +.memo-action-btn:hover { + background: var(--surface-hover); + color: var(--lumi-primary); +} + +.memo-action-btn.danger:hover { + 
background: rgba(244, 63, 94, 0.1); + color: #f43f5e; +} + +.edit-textarea { + width: 100%; + padding: 8px; + border: 1px solid var(--border); + border-radius: 8px; + background: var(--bg); + color: var(--text); + font-size: 13px; + resize: vertical; + font-family: inherit; + outline: none; +} + +.edit-textarea:focus { + border-color: #8b5cf6; +} + +.edit-actions { + display: flex; + gap: 8px; + margin-top: 8px; +} + +.edit-btn { + display: flex; + align-items: center; + gap: 4px; + padding: 4px 10px; + border-radius: 6px; + font-size: 12px; + cursor: pointer; + transition: all 200ms; +} + +.edit-btn.save { + background: rgba(139, 92, 246, 0.1); + color: #8b5cf6; +} + +.edit-btn.save:hover { + background: rgba(139, 92, 246, 0.2); +} + +.edit-btn.save:disabled { + opacity: 0.5; + cursor: default; +} + +.edit-btn.cancel { + background: var(--surface-hover); + color: var(--text-muted); +} + +.edit-btn.cancel:hover { + color: var(--text); +} + .portrait-card { padding: 18px; border-radius: 14px; @@ -730,7 +1219,7 @@ function handleSearch() { margin-bottom: 14px; } -.portrait-interests > svg { +.portrait-interests>svg { color: var(--text-muted); flex-shrink: 0; } @@ -762,9 +1251,192 @@ function handleSearch() { color: var(--lumi-primary); } +.dialog-overlay { + position: fixed; + inset: 0; + background: rgba(0, 0, 0, 0.4); + display: flex; + align-items: center; + justify-content: center; + z-index: 100; +} + +.dialog-card { + width: 460px; + max-width: 90vw; + background: var(--bg); + border-radius: 16px; + box-shadow: var(--shadow-lg); + overflow: hidden; +} + +.dialog-header { + display: flex; + align-items: center; + gap: 8px; + padding: 16px 20px; + border-bottom: 1px solid var(--border); + font-size: 15px; + font-weight: 600; + color: var(--text); +} + +.dialog-close-btn { + margin-left: auto; + display: flex; + align-items: center; + justify-content: center; + width: 28px; + height: 28px; + border-radius: 6px; + color: var(--text-muted); + cursor: pointer; +} + 
+.dialog-close-btn:hover { + background: var(--surface-hover); +} + +.dialog-body { + padding: 20px; + display: flex; + flex-direction: column; + gap: 14px; +} + +.dialog-textarea { + width: 100%; + padding: 12px; + border: 1px solid var(--border); + border-radius: 10px; + background: var(--surface); + color: var(--text); + font-size: 13px; + resize: none; + font-family: inherit; + outline: none; +} + +.dialog-textarea:focus { + border-color: #8b5cf6; +} + +.dialog-textarea::placeholder { + color: var(--text-muted); +} + +.dialog-category { + display: flex; + align-items: center; + gap: 10px; +} + +.category-label { + font-size: 13px; + color: var(--text-muted); + flex-shrink: 0; +} + +.category-select { + flex: 1; + padding: 8px 12px; + border: 1px solid var(--border); + border-radius: 8px; + background: var(--surface); + color: var(--text); + font-size: 13px; + outline: none; +} + +.dialog-footer { + display: flex; + justify-content: flex-end; + gap: 8px; + padding: 14px 20px; + border-top: 1px solid var(--border); +} + +.dialog-btn { + display: flex; + align-items: center; + gap: 6px; + padding: 8px 16px; + border-radius: 8px; + font-size: 13px; + font-weight: 500; + cursor: pointer; + transition: all 200ms; +} + +.dialog-btn.cancel { + background: var(--surface); + color: var(--text-muted); +} + +.dialog-btn.cancel:hover { + background: var(--surface-hover); + color: var(--text); +} + +.dialog-btn.confirm { + background: rgba(139, 92, 246, 0.1); + color: #8b5cf6; + border: 1px solid rgba(139, 92, 246, 0.2); +} + +.dialog-btn.confirm:hover { + background: rgba(139, 92, 246, 0.2); +} + +.dialog-btn.confirm:disabled { + opacity: 0.5; + cursor: default; +} + +.dialog-fade-enter-active { + animation: fade-in 0.25s ease-out; +} + +.dialog-fade-enter-active .dialog-card { + animation: scale-in 0.3s cubic-bezier(0.22, 1, 0.36, 1); +} + +.dialog-fade-leave-active { + animation: fade-in 0.2s ease-out reverse; +} + +@keyframes fade-in { + from { + opacity: 0; + } + + to { 
+ opacity: 1; + } +} + +@keyframes scale-in { + from { + opacity: 0; + transform: scale(0.92); + } + + to { + opacity: 1; + transform: scale(1); + } +} + @keyframes fade-up { - 0% { opacity: 0; transform: translateY(16px); } - 100% { opacity: 1; transform: translateY(0); } + 0% { + opacity: 0; + transform: translateY(16px); + } + + 100% { + opacity: 1; + transform: translateY(0); + } } .animate-fade-up { @@ -772,8 +1444,15 @@ function handleSearch() { } @keyframes slide-left { - 0% { opacity: 0; transform: translateX(24px); } - 100% { opacity: 1; transform: translateX(0); } + 0% { + opacity: 0; + transform: translateX(24px); + } + + 100% { + opacity: 1; + transform: translateX(0); + } } .animate-slide-left { @@ -784,12 +1463,14 @@ function handleSearch() { .memo-list-leave-active { transition: all 300ms ease-in-out; } + .memo-list-enter-from { opacity: 0; transform: translateY(8px); } + .memo-list-leave-to { opacity: 0; transform: translateX(-12px); } - + \ No newline at end of file diff --git a/frontend/src/renderer/src/views/WorkspaceView.vue b/frontend/src/renderer/src/views/WorkspaceView.vue index c1613c4..c3bd2b8 100644 --- a/frontend/src/renderer/src/views/WorkspaceView.vue +++ b/frontend/src/renderer/src/views/WorkspaceView.vue @@ -25,12 +25,22 @@ import { PanelRightOpen, PanelRightClose, Square, + UploadCloud, + FileText, + Image, + File, + Brain, + Download, } from 'lucide-vue-next' import { useRouter } from 'vue-router' import { useChatStore } from '../stores/chat' import { useAgentStore } from '../stores/agent' import { useModelStore } from '../stores/model' import { useSkillStore } from '../stores/skill' +import { useMemoryStore } from '../stores/memory' +import FileUpload from '../components/FileUpload.vue' +import FilePreview from '../components/FilePreview.vue' +import { useFileUpload } from '../composables/useFileUpload' import { getProviderLogo } from '../config/provider-logos' import { marked } from 'marked' @@ -44,6 +54,12 @@ const chatStore = 
useChatStore() const agentStore = useAgentStore() const modelStore = useModelStore() const skillStore = useSkillStore() +const memoryStore = useMemoryStore() + +const showMemoryInject = ref(false) + +const { uploadingFile, isUploading, parsedContent, fileType, fileName, uploadAndForward, clearUploadState } = useFileUpload() +const fileUploadRef = ref | null>(null) const inputText = ref('') const messagesContainer = ref(null) @@ -55,12 +71,36 @@ const showSearchPanel = ref(false) const searchQuery = ref('') const searchResults = ref([]) const copiedId = ref(null) +const showReasoning = ref>({}) +const reasoningRefs = ref>({}) +const reasoningScrollRefs = ref(null) const isNearBottom = ref(true) const SCROLL_BOTTOM_THRESHOLD = 120 const showScrollToBottomBtn = ref(false) const isLoadingCurrentConv = computed(() => chatStore.isLoadingCurrentConversation) let resizeObserver: ResizeObserver | null = null +const showGlobalDropOverlay = ref(false) +let globalDragCounter = 0 +let dragLeaveTimer: ReturnType | null = null + +const showFilePreview = ref(false) +const previewFile = ref<{ name: string; type?: string; content?: string } | null>(null) + +const toastMessage = ref('') +const showToast = ref(false) +let toastTimer: ReturnType | null = null + +const displayToast = (msg: string) => { + if (toastTimer) clearTimeout(toastTimer) + toastMessage.value = msg + showToast.value = true + toastTimer = setTimeout(() => { + showToast.value = false + toastTimer = null + }, 3000) +} + const messages = computed(() => chatStore.messages) const isStreaming = computed(() => chatStore.isStreaming) const isBackendReady = computed(() => chatStore.isBackendReady) @@ -126,13 +166,28 @@ const selectModel = (providerId: string, modelId: string) => { showModelDropdown.value = false } +const canSend = computed(() => { + if (!isBackendReady.value) return false + if (isUploading.value) return false + return inputText.value.trim().length > 0 || !!parsedContent.value +}) + const sendMessage = async 
() => { - if (!inputText.value.trim()) return - if (!isBackendReady.value) return + if (!canSend.value) return + + let content = inputText.value.trim() + const fileContent = parsedContent.value + const currentFileName = fileName.value + const currentFileType = fileType.value + + if (!content && fileContent) { + content = '请帮我分析上传的文件' + } - const content = inputText.value inputText.value = '' resetTextareaHeight() + clearUploadState() + fileUploadRef.value?.clearUploadState() const agent = agentStore.activeAgent const resolved = modelStore.resolveModel @@ -147,6 +202,12 @@ const sendMessage = async () => { if (agent?.systemPrompt) options.systemPrompt = agent.systemPrompt if (agent?.id) options.agentId = agent.id + if (fileContent) { + options.fileContent = fileContent + options.fileType = currentFileType + options.fileName = currentFileName + } + isNearBottom.value = true await chatStore.sendMessage(content, options) await nextTick() @@ -204,6 +265,22 @@ const renderMarkdown = (text: string): string => { return marked.parse(text) as string } +const getFileIcon = (fileType?: string) => { + if (!fileType) return File + if (fileType === 'image') return Image + return FileText +} + +const openFilePreview = (file: { name: string; type?: string; content?: string }) => { + previewFile.value = { name: file.name, type: file.type, content: file.content } + showFilePreview.value = true +} + +const closeFilePreview = () => { + showFilePreview.value = false + previewFile.value = null +} + const contextUsage = computed(() => { // 修复:倒序渲染问题 - 改用正序查找最后一条完成的助手消息 const lastAssistantMsg = messages.value.findLast(m => m.role === 'assistant' && m.done) @@ -215,6 +292,46 @@ const contextPercent = computed(() => { return Math.min(100, Math.round((contextUsage.value.totalTokens / modelStore.modelConfig.defaultMaxTokens) * 100)) }) +const toggleReasoning = (msgId: string) => { + showReasoning.value = { + ...showReasoning.value, + [msgId]: !showReasoning.value[msgId] + } +} + +const 
lastAssistantMsg = computed(() => { + const msgs = messages.value + if (msgs.length === 0) return null + const last = msgs[msgs.length - 1] + return last && last.role === 'assistant' ? last : null +}) + +const reasoningIsRunning = computed(() => { + const msg = lastAssistantMsg.value + if (!msg) return false + return !msg.done && (!msg.content || msg.content.length === 0) && (msg.reasoningContent !== undefined) +}) + +watch(() => messages.value, async (msgs) => { + for (const msg of msgs) { + if (msg.role !== 'assistant') continue + if (msg.content && msg.content.length > 0 && showReasoning.value[msg.id] === undefined) { + showReasoning.value = { ...showReasoning.value, [msg.id]: false } + } + } + await nextTick() + // 对所有正在推理的消息自动滚动到底部 + const scrollEls = reasoningScrollRefs.value + if (scrollEls) { + const els = Array.isArray(scrollEls) ? scrollEls : [scrollEls] + for (const el of els) { + if (el && el.scrollHeight > el.clientHeight) { + el.scrollTop = el.scrollHeight + } + } + } +}, { deep: false, immediate: true }) + const copyMessage = async (msgId: string, content: string) => { try { await navigator.clipboard.writeText(content) @@ -223,19 +340,121 @@ const copyMessage = async (msgId: string, content: string) => { } catch {} } +const handleGlobalDragEnter = (e: DragEvent) => { + if (e.dataTransfer?.types.includes('Files')) { + e.preventDefault() + if (dragLeaveTimer) { + clearTimeout(dragLeaveTimer) + dragLeaveTimer = null + } + globalDragCounter++ + showGlobalDropOverlay.value = true + } +} + +const handleGlobalDragOver = (e: DragEvent) => { + if (e.dataTransfer?.types.includes('Files')) { + e.preventDefault() + if (dragLeaveTimer) { + clearTimeout(dragLeaveTimer) + dragLeaveTimer = null + } + showGlobalDropOverlay.value = true + } +} + +const handleGlobalDragLeave = (e: DragEvent) => { + if (e.dataTransfer?.types.includes('Files')) { + e.preventDefault() + globalDragCounter-- + if (globalDragCounter <= 0) { + dragLeaveTimer = setTimeout(() => { + 
showGlobalDropOverlay.value = false + globalDragCounter = 0 + }, 100) + } + } +} + +const allowedExtensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.pdf', '.docx', '.doc', '.txt', '.md', '.csv', '.json', '.xml', '.html', '.css', '.js', '.py', '.java', '.cpp', '.c', '.h', '.go', '.rs', '.ts', '.sql', '.yaml', '.yml'] + +const isFileAllowed = (fileName: string): boolean => { + const ext = fileName.toLowerCase().substring(fileName.lastIndexOf('.')) + return allowedExtensions.includes(ext) +} + +const handleGlobalDrop = async (e: DragEvent) => { + e.preventDefault() + showGlobalDropOverlay.value = false + globalDragCounter = 0 + if (dragLeaveTimer) { + clearTimeout(dragLeaveTimer) + dragLeaveTimer = null + } + const files = e.dataTransfer?.files + if (files && files.length > 0 && !isUploading.value) { + const file = files[0] + if (isFileAllowed(file.name)) { + await uploadAndForward(file) + } else { + displayToast(`不支持的文件类型: ${file.name}`) + } + } +} + +const handlePaste = async (e: ClipboardEvent) => { + const items = e.clipboardData?.items + if (!items) return + for (let i = 0; i < items.length; i++) { + if (items[i].kind === 'file') { + const file = items[i].getAsFile() + if (file && !isUploading.value) { + if (isFileAllowed(file.name)) { + e.preventDefault() + await uploadAndForward(file) + return + } else { + displayToast(`不支持的文件类型: ${file.name}`) + } + } + } + } +} + const formatTime = (dateStr: string) => { - // 修复:历史记录实时更新 - 处理Invalid Date问题 if (!dateStr || dateStr === 'undefined' || dateStr === 'null') { return '刚刚' } try { - const d = new Date(dateStr) + let d: Date + const numDate = Number(dateStr) + if (!isNaN(numDate)) { + d = new Date(numDate) + } else { + d = new Date(dateStr) + } + if (isNaN(d.getTime())) { return '刚刚' } const now = new Date() + const diffMs = now.getTime() - d.getTime() + const diffMins = Math.floor(diffMs / 60000) + const diffHours = Math.floor(diffMs / 3600000) + const diffDays = Math.floor(diffMs / 86400000) + + if 
(diffMins < 1) { + return '刚刚' + } else if (diffMins < 60) { + return `${diffMins}分钟前` + } else if (diffHours < 24) { + return `${diffHours}小时前` + } else if (diffDays < 7) { + return `${diffDays}天前` + } + const isToday = d.toDateString() === now.toDateString() if (isToday) { @@ -304,6 +523,37 @@ const handleClickOutsideModel = (e: MouseEvent) => { } } +async function injectMemoryToInput() { + showMemoryInject.value = true + try { + const result = await memoryStore.fetchInjectionContent(agentStore.activeAgent?.id) + if (result.has_memory && result.content) { + inputText.value = `\n\n---\n系统已注入以下用户记忆,请参考:\n${result.content}\n---\n\n${inputText.value}` + } + } finally { + showMemoryInject.value = false + } +} + +function handleChatTrigger(event: CustomEvent) { + if (event.detail?.message) { + inputText.value = event.detail.message + } +} + +function handleMemoryChatTrigger(event: CustomEvent) { + const text = event.detail?.text + if (text) { + inputText.value = `关于我之前提到的「${text.slice(0, 80)}」,请帮我进一步分析。` + } +} + +function handleMemoryChatTriggerDirect(text: string) { + inputText.value = `关于我之前提到的「${text.slice(0, 80)}」,请帮我进一步分析。` +} + +(window as any).__memoryChatTrigger = handleMemoryChatTriggerDirect + onMounted(async () => { await chatStore.checkBackend() if (chatStore.isBackendReady) { @@ -317,12 +567,26 @@ onMounted(async () => { ]) } document.addEventListener('click', handleClickOutsideModel) + document.addEventListener('dragenter', handleGlobalDragEnter) + document.addEventListener('dragover', handleGlobalDragOver) + document.addEventListener('dragleave', handleGlobalDragLeave) + document.addEventListener('drop', handleGlobalDrop) + document.addEventListener('paste', handlePaste) + window.addEventListener('luominest:chat-trigger', handleChatTrigger as EventListener) + window.addEventListener('luominest:memory-chat-trigger', handleMemoryChatTrigger as EventListener) nextTick(() => setupResizeObserver()) }) onBeforeUnmount(() => { resizeObserver?.disconnect() 
document.removeEventListener('click', handleClickOutsideModel) + document.removeEventListener('dragenter', handleGlobalDragEnter) + document.removeEventListener('dragover', handleGlobalDragOver) + document.removeEventListener('dragleave', handleGlobalDragLeave) + document.removeEventListener('drop', handleGlobalDrop) + document.removeEventListener('paste', handlePaste) + window.removeEventListener('luominest:chat-trigger', handleChatTrigger as EventListener) + window.removeEventListener('luominest:memory-chat-trigger', handleMemoryChatTrigger as EventListener) }) @@ -403,11 +667,53 @@ onBeforeUnmount(() => {
    {{ agentStore.activeAgent?.name || 'LuomiNest' }}
    -
    -
    {{ msg.content }}
    -
    - - 正在分析问题... +
    +
    + + + + + + + + + +
    +
    + {{ msg.reasoningContent || '...' }} +
    +
    +
    +
    + + 已中断 + +
    +
    + 已中断 +
    +
    + {{ msg.content }} +
    +
    + + {{ file.name }} + +
    +
    @@ -455,6 +761,7 @@ onBeforeUnmount(() => {
    +