diff --git a/bridges/wechat.py b/bridges/wechat.py index f83ae83..7115abd 100644 --- a/bridges/wechat.py +++ b/bridges/wechat.py @@ -268,6 +268,7 @@ def _wx_qr_login(config: dict, bot_type: str = _ILINK_DEFAULT_BOT_TYPE, def _wx_poll_loop(token: str, base_url: str, config: dict) -> str: """Returns "stopped", "auth_error", or raises on unexpected fatal error.""" from tools import _wx_thread_local + from bridges import wechat_smart_reply as _sr session_ctx = runtime.get_session_ctx(config.get("_session_id", "default")) run_query_cb = session_ctx.run_query sync_buf = "" @@ -275,6 +276,15 @@ def _wx_poll_loop(token: str, base_url: str, config: dict) -> str: session_ctx.wx_send = lambda uid, txt: _wx_send(uid, txt, config) + # Smart-reply panel store (SQLite-backed; falls back to in-memory) and + # contacts loader, lifecycles bound to the poll loop. + _smart_store = _sr.make_store( + timeout_s=float(config.get("wechat_smart_reply_timeout_s", + _sr.DEFAULT_TIMEOUT_S)), + ) + _smart_store.start_janitor() + _smart_contacts = _sr.ContactsStore() + while not _wechat_stop.is_set(): try: result = _wx_get_updates(base_url, token, sync_buf) @@ -346,6 +356,20 @@ def _wx_poll_loop(token: str, base_url: str, config: dict) -> str: evt.set() continue + # ── Smart-reply: filehelper input routes panel choice ────── + # Only consume the message if there's an active panel and + # the user is responding to it; otherwise fall through so + # they can still use !jobs / etc. from filehelper. + if _sr.is_filehelper(from_uid): + consumed = _sr.handle_filehelper_message( + text, _smart_store, + send_to_target=lambda uid, txt: _wx_send(uid, txt, config), + send_to_filehelper=lambda txt: _wx_send(_sr._FILEHELPER_UID, txt, config), + ) + if consumed: + print(clr(f"\n ✓ smart-reply panel resolved", "dim")) + continue + print(clr(f"\n 📩 WeChat [{from_uid[:8]}]: {text}", "cyan")) # ── Interactive PTY session ──────────────────────────────── @@ -560,6 +584,27 @@ def _wx_terminal(cmd, uid, skey): _pending_evt.set() continue + # ── Smart-reply: whitelisted contact → draft, don't auto-reply ─ + if _sr.is_smart_reply_target(from_uid, config, text=text): + label = (msg.get("from_user_nickname") + or msg.get("from_username") + or from_uid[:8]) + triggered = _sr.trigger_smart_reply( + target_uid=from_uid, + target_label=str(label), + message=text, + store=_smart_store, + config=config, + send_to_filehelper=lambda txt: _wx_send(_sr._FILEHELPER_UID, txt, config), + contacts=_smart_contacts, + ) + if triggered: + print(clr(f" ↳ smart-reply panel sent to filehelper", "dim")) + continue + # Generation failed → fall through to normal dispatch so + # the user still gets *some* response. + print(clr(f" ⚠ smart-reply candidate generation failed; falling back to auto-reply", "yellow")) + # ── Claude query: create job, queue if busy, else run now ── job = _jobs.create(text, source="wechat") diff --git a/bridges/wechat_smart_reply.py b/bridges/wechat_smart_reply.py new file mode 100644 index 0000000..5eb02be --- /dev/null +++ b/bridges/wechat_smart_reply.py @@ -0,0 +1,473 @@ +"""WeChat smart-reply panel — AI drafts, user approves on phone via 文件传输助手. + +Flow: + 1. Inbound message from a whitelisted contact arrives. + 2. Generate 3 candidate replies via the auxiliary cheap model, conditioned on: + * the contact's relationship/notes from ~/.cheetahclaws/wx_contacts.json + * the user's recent confirmed replies (style mimicking) + 3. Send a panel to filehelper, tagged with a 2-letter ID: + 💬 [AA] 张三 → "周末有空吗" + [1] 有的,周六下午行 + [2] 周末出差,下周可以吗 + [3] 在忙,晚点回你 + 回 1/2/3 发送 · 直接打字自定义 · x 跳过 · q 看队列 + 4. Filehelper input is interpreted against the active queue: + - "1" / "2" / "3" → send candidate from latest active panel + - "x" → skip latest active panel + - "" → send freeform reply for latest active panel + - "AA 1" / "AA x" / "AA hi" → address panel by ID explicitly + - "q" / "queue" → list pending panels + 5. Confirmed sends are appended to wx_reply_history and feed style mimicking. + 6. Panels expire after wechat_smart_reply_timeout_s (default 5 min). + The store's janitor sweeps expired rows. + +Group rules: + * Group chats off by default. + * `wechat_smart_reply_groups: true` enables them. + * `wechat_smart_reply_groups_at_only: true` further restricts groups to + messages that contain @. + +Storage: + * Panels + reply history persist in SQLite (~/.cheetahclaws/wx_smart_reply.db). + * Contacts persist in JSON (~/.cheetahclaws/wx_contacts.json). + * SQLite init failure auto-falls-back to in-memory; nothing crashes. +""" +from __future__ import annotations + +import re +import time +from dataclasses import dataclass +from typing import Callable, Optional + +from .wechat_smart_reply_store import ( + Contact, + ContactsStore, + DEFAULT_TIMEOUT_S, + InMemoryStore, + PendingPanel, + ReplyHistoryEntry, + SqliteStore, + make_store, + n_to_id, +) + +# Re-export so callers (including tests) get a single import surface. +__all__ = [ + "Contact", "ContactsStore", "PendingPanel", "ReplyHistoryEntry", + "InMemoryStore", "SqliteStore", "make_store", "n_to_id", + "DEFAULT_TIMEOUT_S", + "is_filehelper", "is_group", "is_smart_reply_target", + "ParsedAction", "parse_filehelper_input", + "format_panel", "format_queue", + "generate_candidates", + "trigger_smart_reply", "handle_filehelper_message", +] + + +# ── Identity heuristics ──────────────────────────────────────────────────── + +_FILEHELPER_UID = "filehelper" +_GROUP_SUFFIX = "@chatroom" + + +def is_filehelper(uid: str) -> bool: + return uid == _FILEHELPER_UID + + +def is_group(uid: str) -> bool: + return uid.endswith(_GROUP_SUFFIX) + + +def _matches_at_mention(text: str, nickname: str) -> bool: + """True if the text contains `@` (with word-ish boundary). + + WeChat clients typically render group @-mentions as `@` followed + by a space or end-of-message; we accept that or a CJK boundary. + """ + if not nickname: + return False + # Allow nickname plus a trailing space, end-of-string, or punctuation. + pattern = r"@" + re.escape(nickname) + r"(\s|$|[,,。.!!??::])" + return bool(re.search(pattern, text)) + + +def is_smart_reply_target(uid: str, config: dict, *, text: str = "") -> bool: + """Return True iff a message from ``uid`` should go through smart reply. + + Rules (in order): + * feature flag off → False + * filehelper → False (would loop) + * group + groups disabled → False + * group + groups_at_only + no @ in text → False + * whitelist set and uid not in it → False + * otherwise → True + """ + if not config.get("wechat_smart_reply", False): + return False + if is_filehelper(uid): + return False + if is_group(uid): + if not config.get("wechat_smart_reply_groups", False): + return False + if config.get("wechat_smart_reply_groups_at_only", False): + nickname = (config.get("wechat_self_nickname") or "").strip() + if not nickname or not _matches_at_mention(text, nickname): + return False + whitelist = config.get("wechat_smart_reply_whitelist") or [] + if whitelist and uid not in whitelist: + return False + return True + + +# ── Filehelper input parsing ────────────────────────────────────────────── + +# A panel ID is exactly 2 uppercase letters at the start of input, +# optionally followed by space + payload. +_PANEL_ID_RE = re.compile(r"^([A-Z]{2})(?:\s+(.*))?$") + + +@dataclass(frozen=True) +class ParsedAction: + kind: str # "send" | "skip" | "list" | "noop" + panel_id: Optional[str] # explicit ID if user prefixed with "AA"; else None + text: Optional[str] # send-text (for kind == "send") + + +def parse_filehelper_input(text: str, store) -> ParsedAction: + """Interpret a filehelper-incoming message against the active panel queue. + + Behaviour: + * "q" / "queue" → kind="list" + * " " → kind=send/skip/list, panel_id set + * "1" / "2" / "3" → send candidate from latest active panel + * "x" / "skip" / "跳过" → skip latest active panel + * any other text → freeform send for latest active panel + Returns ``noop`` if there's no active panel and the input wasn't ``q``. + """ + s = text.strip() + if not s: + return ParsedAction("noop", None, None) + + low = s.lower() + if low in ("q", "queue", "/q", "/queue", "队列"): + return ParsedAction("list", None, None) + + # Explicit panel-ID prefix: "AA 1", "AA x", "AA hello", or "AA" alone (=list) + m = _PANEL_ID_RE.match(s) + if m: + pid = m.group(1) + rest = (m.group(2) or "").strip() + return _classify_choice(rest, panel=store.get_by_id(pid), panel_id=pid) + + # No panel-ID prefix → applies to the latest active panel + active = store.take_active() if store is not None else None + return _classify_choice(s, panel=active, panel_id=None) + + +def _classify_choice(payload: str, panel: Optional[PendingPanel], + panel_id: Optional[str]) -> ParsedAction: + if panel is None: + return ParsedAction("noop", panel_id, None) + + s = payload.strip() + low = s.lower() + + if low in ("", "show", "preview"): + # "AA" alone means "show me this panel" — surfaced as `list` with id + # so the caller can reformat or ignore. + return ParsedAction("list", panel_id, None) + if low in ("x", "skip", "/skip", "/x", "跳过", "不回"): + return ParsedAction("skip", panel_id or panel.panel_id, None) + if s in ("1", "2", "3"): + idx = int(s) - 1 + if 0 <= idx < len(panel.candidates): + return ParsedAction("send", panel_id or panel.panel_id, + panel.candidates[idx]) + return ParsedAction("send", panel_id or panel.panel_id, s) + + +# ── Panel + queue formatting ────────────────────────────────────────────── + + +def format_panel(panel: PendingPanel) -> str: + label = panel.target_label or panel.target_uid[:8] + msg = panel.message[:200] + if len(panel.message) > 200: + msg += "…" + lines = [f"💬 [{panel.panel_id}] {label} → 「{msg}」", ""] + for i, cand in enumerate(panel.candidates, start=1): + lines.append(f"[{i}] {cand}") + lines.append("") + lines.append("回 1/2/3 发送 · 直接打字自定义 · x 跳过 · q 看队列") + return "\n".join(lines) + + +def format_queue(panels: list[PendingPanel]) -> str: + """Render the pending-panel queue for filehelper. + + Sorted oldest-first so the user sees what's been waiting longest. + """ + if not panels: + return "📋 当前没有待处理的消息" + now = time.time() + lines = ["📋 待处理 (oldest first):", ""] + for p in panels: + age = max(0, int(now - p.created_at)) + ttl = max(0, int(p.expires_at - now)) + label = p.target_label or p.target_uid[:8] + msg_preview = p.message[:40] + if len(p.message) > 40: + msg_preview += "…" + lines.append( + f" [{p.panel_id}] {label} ({_fmt_secs(age)}前 · 还剩 {_fmt_secs(ttl)})" + ) + lines.append(f" 「{msg_preview}」") + lines.append("") + lines.append("发 1/2/3/x/freeform 处理 · 例如 AA 2") + return "\n".join(lines) + + +def _fmt_secs(s: int) -> str: + if s < 60: + return f"{s}s" + if s < 3600: + return f"{s // 60}m" + return f"{s // 3600}h{(s % 3600) // 60}m" + + +# ── Candidate generation ────────────────────────────────────────────────── + + +def _build_prompt(message: str, sender_label: str, + contact: Optional[Contact], + history: list[ReplyHistoryEntry]) -> str: + parts: list[str] = ["你是用户的微信回复助手。用户刚收到一条消息," + "你需要起草 3 个简短自然的候选回复。"] + + if contact and (contact.relationship or contact.notes): + ctx_lines = [f"\n关于发件人 {contact.label or sender_label}:"] + if contact.relationship: + ctx_lines.append(f"- 关系: {contact.relationship}") + if contact.notes: + ctx_lines.append(f"- 备注: {contact.notes}") + parts.append("\n".join(ctx_lines)) + + if history: + examples = [] + for h in history[:10]: + txt = (h.text or "").strip() + if txt: + examples.append(f"- {txt}") + if examples: + parts.append("\n用户最近发出的几条回复(请模仿这种语气和长度," + "不要照抄内容):\n" + "\n".join(examples)) + + parts.append(f"\n收到的消息(来自 {sender_label}):\n{message[:500]}") + + parts.append( + "\n要求:\n" + "- 每条回复 ≤ 30 字\n" + "- 语气:日常、自然、像真人随手回的\n" + "- 3 条之间风格略有差异(例如:肯定 / 委婉拒绝 / 模糊延后)\n" + "- 不要使用 emoji 或复杂标点\n" + "- 不要解释你的选择,只输出回复\n" + "\n格式(严格遵守,每行一条,开头数字+点+空格):\n" + "1. <回复一>\n" + "2. <回复二>\n" + "3. <回复三>" + ) + + return "\n".join(parts) + + +_LIST_RE = re.compile(r"^\s*[1-3][\.\)、]\s*(.+)$") + + +def _parse_candidates(text: str) -> list[str]: + out: list[str] = [] + for line in text.splitlines(): + m = _LIST_RE.match(line) + if m: + cand = m.group(1).strip().rstrip("。,,.!?!?") + if cand: + out.append(cand) + return out[:3] + + +def generate_candidates( + message: str, + sender_label: str, + config: dict, + *, + contact: Optional[Contact] = None, + history: Optional[list[ReplyHistoryEntry]] = None, + _stream_fn: Optional[Callable] = None, +) -> list[str]: + """Produce 3 candidate replies via the auxiliary cheap model. + + Returns ``[]`` on any failure — caller falls back or skips. + """ + prompt = _build_prompt(message, sender_label, contact, history or []) + + if _stream_fn is None: + try: + import providers + from auxiliary import get_auxiliary_model + except Exception: + return [] + _stream_fn = providers.stream + model = get_auxiliary_model(config) + else: + model = config.get("auxiliary_model") or "test-model" + + try: + chunks = [] + for event in _stream_fn( + model=model, + system="只输出 3 个候选回复,按要求的格式。不要任何额外说明。", + messages=[{"role": "user", "content": prompt}], + tool_schemas=[], + config={**config, "max_tokens": 200, "thinking": False}, + ): + t = getattr(event, "text", None) + if t: + chunks.append(t) + return _parse_candidates("".join(chunks)) + except Exception: + return [] + + +# ── Panel constructor ────────────────────────────────────────────────────── + + +def make_panel(target_uid: str, target_label: str, message: str, + candidates: list[str], *, + panel_id: str, + timeout_s: float = DEFAULT_TIMEOUT_S) -> PendingPanel: + now = time.time() + return PendingPanel( + panel_id=panel_id, + target_uid=target_uid, + target_label=target_label, + message=message, + candidates=candidates, + created_at=now, + expires_at=now + timeout_s, + ) + + +# ── High-level entry points ──────────────────────────────────────────────── + + +def trigger_smart_reply( + target_uid: str, + target_label: str, + message: str, + store, + config: dict, + *, + send_to_filehelper: Callable[[str], None], + contacts: Optional[ContactsStore] = None, + generate_fn: Optional[Callable] = None, +) -> bool: + """Generate candidates, store the panel, push it to filehelper. + + Returns True on success, False if generation failed. + """ + contact = contacts.get(target_uid) if contacts else None + history = store.recent_replies(n=10, exclude_uid=target_uid) \ + if hasattr(store, "recent_replies") else [] + gen = generate_fn or generate_candidates + candidates = gen( + message, target_label, config, + contact=contact, history=history, + ) + if not candidates: + return False + if len(candidates) < 3: + # Pad with conservative fallbacks so the panel is always 3 wide. + for f in ("好", "稍后回你", "我看一下"): + if len(candidates) >= 3: + break + if f not in candidates: + candidates.append(f) + timeout_s = float(config.get("wechat_smart_reply_timeout_s", DEFAULT_TIMEOUT_S)) + pid = store.assign_next_id() + panel = make_panel(target_uid, target_label, message, candidates[:3], + panel_id=pid, timeout_s=timeout_s) + store.put(panel) + send_to_filehelper(format_panel(panel)) + return True + + +def handle_filehelper_message( + text: str, + store, + *, + send_to_target: Callable[[str, str], None], + send_to_filehelper: Callable[[str], None], +) -> bool: + """Route a filehelper-incoming message against the active panel queue. + + Returns True if the message was consumed (caller should NOT pass it on + to the agent), False if the input wasn't a smart-reply action and + should fall through to the normal bridge dispatch. + """ + action = parse_filehelper_input(text, store) + + if action.kind == "noop": + return False + + if action.kind == "list": + if action.panel_id: + # "AA" alone — show that one panel + p = store.get_by_id(action.panel_id) + if p is None: + send_to_filehelper(f"⚠ [{action.panel_id}] 已过期或不存在") + return True + send_to_filehelper(format_panel(p)) + return True + send_to_filehelper(format_queue(store.list_active())) + return True + + # send / skip always have a panel_id by this point — parse_filehelper_input + # fills it in either from the explicit ID prefix or from the latest-active + # panel. Defensive None handling kept for forward-compat. + if not action.panel_id: + return False + panel = store.consume_by_id(action.panel_id) + if panel is None: + send_to_filehelper(f"⚠ [{action.panel_id}] 已过期或不存在") + return True + + if action.kind == "skip": + send_to_filehelper(f"⏭ 已跳过 [{panel.panel_id}] {panel.target_label}") + return True + + if action.kind == "send" and action.text: + send_to_target(panel.target_uid, action.text) + if hasattr(store, "write_reply"): + try: + source = _classify_source(action.text, panel.candidates) + store.write_reply( + to_uid=panel.target_uid, + to_label=panel.target_label, + text=action.text, + source=source, + ) + except Exception: + pass + send_to_filehelper( + f"✓ 已发送给 [{panel.panel_id}] {panel.target_label}:{action.text[:60]}" + ) + return True + + return False + + +def _classify_source(text: str, candidates: list[str]) -> str: + """Tag a confirmed reply as candidate_N (matched) or 'freeform'.""" + for i, c in enumerate(candidates, start=1): + if text == c: + return f"candidate_{i}" + return "freeform" diff --git a/bridges/wechat_smart_reply_store.py b/bridges/wechat_smart_reply_store.py new file mode 100644 index 0000000..bed181c --- /dev/null +++ b/bridges/wechat_smart_reply_store.py @@ -0,0 +1,595 @@ +"""WeChat smart-reply persistent storage. + +Three concerns share a single SQLite file at +``~/.cheetahclaws/wx_smart_reply.db`` so a bridge restart doesn't drop +panels mid-conversation, and so style mimicking has past replies to +draw from after a daemon recycle: + +- ``wx_panels`` — pending PermissionPanels (panel_id, target, + candidates, expires_at). Replaces the + in-memory ring; janitor sweeps on access. +- ``wx_reply_history`` — every confirmed send; smart-reply prompt + reads the last N rows as style examples. +- ``wx_contacts`` — per-uid relationship/notes; mirrored from + ``~/.cheetahclaws/wx_contacts.json`` (JSON + stays the source of truth because it's + user-edited; the SQLite mirror is just a + cache for fast lookup). + +Schema is additive and idempotent. No migrations needed for v1. + +Threading: SQLite ``check_same_thread=False`` plus a connection lock so +the poll loop, janitor, and any future RPC handler can share one +connection. Each method is short — no long transactions — so the lock +contention is negligible. +""" +from __future__ import annotations + +import json +import sqlite3 +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Iterable, Iterator, Optional + +# ── Constants ────────────────────────────────────────────────────────────── + +DEFAULT_DB_PATH = Path.home() / ".cheetahclaws" / "wx_smart_reply.db" +DEFAULT_CONTACTS_JSON = Path.home() / ".cheetahclaws" / "wx_contacts.json" +DEFAULT_TIMEOUT_S = 5 * 60 +JANITOR_TICK_S = 30.0 +HISTORY_KEEP_DAYS = 30 # rows older than this are pruned by the janitor + +_SCHEMA = [ + """CREATE TABLE IF NOT EXISTS wx_panels ( + panel_id TEXT PRIMARY KEY, + target_uid TEXT NOT NULL, + target_label TEXT NOT NULL, + message TEXT NOT NULL, + candidates TEXT NOT NULL, + created_at REAL NOT NULL, + expires_at REAL NOT NULL + )""", + "CREATE INDEX IF NOT EXISTS idx_wx_panels_expires ON wx_panels(expires_at)", + "CREATE INDEX IF NOT EXISTS idx_wx_panels_target ON wx_panels(target_uid)", + """CREATE TABLE IF NOT EXISTS wx_reply_history ( + ts REAL NOT NULL, + to_uid TEXT NOT NULL, + to_label TEXT, + text TEXT NOT NULL, + source TEXT + )""", + "CREATE INDEX IF NOT EXISTS idx_wx_history_ts ON wx_reply_history(ts)", + """CREATE TABLE IF NOT EXISTS wx_id_counter ( + id INTEGER PRIMARY KEY CHECK (id = 1), + n INTEGER NOT NULL + )""", + "INSERT OR IGNORE INTO wx_id_counter (id, n) VALUES (1, 0)", +] + + +# ── Domain types ─────────────────────────────────────────────────────────── + + +@dataclass +class PendingPanel: + panel_id: str # 2-letter monotonic ID, e.g. "AA" + target_uid: str + target_label: str + message: str + candidates: list[str] + created_at: float + expires_at: float + + +@dataclass(frozen=True) +class ReplyHistoryEntry: + ts: float + to_uid: str + to_label: Optional[str] + text: str + source: Optional[str] + + +@dataclass(frozen=True) +class Contact: + uid: str + label: Optional[str] = None + relationship: Optional[str] = None + notes: Optional[str] = None + + +# ── Panel-ID assignment (AA..ZZ rolling) ─────────────────────────────────── + +def n_to_id(n: int) -> str: + """Map a non-negative integer to a 2-letter base-26 ID (AA..ZZ). + + Wraps every 676 panels. In practice users never queue >100 + simultaneously, so collisions across active panels are not a concern; + if one ever happened, the SQLite primary key would refuse the insert + and the bridge would log a warning and assign the next id. + """ + n = n % (26 * 26) + a, b = divmod(n, 26) + return chr(ord("A") + a) + chr(ord("A") + b) + + +# ── SQLite-backed store ──────────────────────────────────────────────────── + + +class SqliteStore: + """Panel + reply-history persistence backed by SQLite. + + Public surface mirrors the in-memory ``InMemoryStore`` so the bridge + can swap implementations without further changes. + """ + + def __init__(self, db_path: Optional[Path] = None, + *, timeout_s: float = DEFAULT_TIMEOUT_S) -> None: + self.db_path = Path(db_path) if db_path is not None else DEFAULT_DB_PATH + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._timeout_s = timeout_s + # check_same_thread=False because the poll loop and janitor are + # different threads. We protect with our own lock. + self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._conn.row_factory = sqlite3.Row + self._lock = threading.Lock() + self._init_schema() + self._stop = threading.Event() + self._janitor: Optional[threading.Thread] = None + + def _init_schema(self) -> None: + with self._txn() as cur: + for stmt in _SCHEMA: + cur.execute(stmt) + + @contextmanager + def _txn(self): + with self._lock: + cur = self._conn.cursor() + try: + yield cur + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + cur.close() + + # ── Janitor lifecycle ───────────────────────────────────────────── + + def start_janitor(self) -> None: + if self._janitor is not None: + return + self._janitor = threading.Thread( + target=self._janitor_loop, name="wx-smart-reply-janitor", daemon=True, + ) + self._janitor.start() + + def stop(self) -> None: + self._stop.set() + if self._janitor is not None: + self._janitor.join(timeout=2.0) + try: + self._conn.close() + except Exception: + pass + + def _janitor_loop(self) -> None: + while not self._stop.wait(JANITOR_TICK_S): + try: + self.sweep_expired() + self.prune_history(older_than_days=HISTORY_KEEP_DAYS) + except Exception: + # Janitor failures are non-fatal; the next tick will retry. + pass + + # ── Panel ID generation ─────────────────────────────────────────── + + def assign_next_id(self) -> str: + with self._txn() as cur: + cur.execute("UPDATE wx_id_counter SET n = n + 1 WHERE id = 1") + cur.execute("SELECT n FROM wx_id_counter WHERE id = 1") + n = cur.fetchone()["n"] + return n_to_id(n - 1) + + # ── Panel operations ────────────────────────────────────────────── + + def put(self, panel: PendingPanel) -> None: + with self._txn() as cur: + cur.execute( + """INSERT OR REPLACE INTO wx_panels + (panel_id, target_uid, target_label, message, + candidates, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (panel.panel_id, panel.target_uid, panel.target_label, + panel.message, json.dumps(panel.candidates, ensure_ascii=False), + panel.created_at, panel.expires_at), + ) + + def take_active(self) -> Optional[PendingPanel]: + """Return the most-recently-created non-expired panel, or None.""" + now = time.time() + with self._txn() as cur: + cur.execute( + """SELECT * FROM wx_panels + WHERE expires_at > ? + ORDER BY created_at DESC LIMIT 1""", + (now,), + ) + row = cur.fetchone() + return _row_to_panel(row) if row else None + + def get_by_id(self, panel_id: str) -> Optional[PendingPanel]: + now = time.time() + with self._txn() as cur: + cur.execute( + """SELECT * FROM wx_panels + WHERE panel_id = ? AND expires_at > ?""", + (panel_id, now), + ) + row = cur.fetchone() + return _row_to_panel(row) if row else None + + def list_active(self) -> list[PendingPanel]: + now = time.time() + with self._txn() as cur: + cur.execute( + """SELECT * FROM wx_panels + WHERE expires_at > ? + ORDER BY created_at ASC""", + (now,), + ) + rows = cur.fetchall() + return [_row_to_panel(r) for r in rows] + + def consume(self, target_uid: str) -> Optional[PendingPanel]: + """Remove and return the active panel for this target_uid, if any.""" + now = time.time() + with self._txn() as cur: + cur.execute( + """SELECT * FROM wx_panels + WHERE target_uid = ? AND expires_at > ? + ORDER BY created_at DESC LIMIT 1""", + (target_uid, now), + ) + row = cur.fetchone() + if row is None: + return None + cur.execute("DELETE FROM wx_panels WHERE panel_id = ?", + (row["panel_id"],)) + return _row_to_panel(row) + + def consume_by_id(self, panel_id: str) -> Optional[PendingPanel]: + with self._txn() as cur: + cur.execute( + "SELECT * FROM wx_panels WHERE panel_id = ?", + (panel_id,), + ) + row = cur.fetchone() + if row is None: + return None + cur.execute("DELETE FROM wx_panels WHERE panel_id = ?", (panel_id,)) + return _row_to_panel(row) + + def sweep_expired(self) -> int: + now = time.time() + with self._txn() as cur: + cur.execute("DELETE FROM wx_panels WHERE expires_at <= ?", (now,)) + return cur.rowcount + + def __len__(self) -> int: + return len(self.list_active()) + + # ── Reply history ───────────────────────────────────────────────── + + def write_reply(self, *, to_uid: str, to_label: Optional[str], + text: str, source: Optional[str] = None, + ts: Optional[float] = None) -> None: + with self._txn() as cur: + cur.execute( + """INSERT INTO wx_reply_history (ts, to_uid, to_label, text, source) + VALUES (?, ?, ?, ?, ?)""", + (ts if ts is not None else time.time(), + to_uid, to_label, text, source), + ) + + def recent_replies(self, n: int = 20, + *, exclude_uid: Optional[str] = None) -> list[ReplyHistoryEntry]: + """Return up to ``n`` most-recent replies, newest first. + + ``exclude_uid`` skips replies that went to a specific contact — + used by candidate generation to avoid leaking the *current* + thread's drafts back as "style examples" for itself. + """ + with self._txn() as cur: + if exclude_uid: + cur.execute( + """SELECT * FROM wx_reply_history + WHERE to_uid != ? + ORDER BY ts DESC LIMIT ?""", + (exclude_uid, n), + ) + else: + cur.execute( + """SELECT * FROM wx_reply_history + ORDER BY ts DESC LIMIT ?""", + (n,), + ) + rows = cur.fetchall() + return [ + ReplyHistoryEntry( + ts=r["ts"], to_uid=r["to_uid"], + to_label=r["to_label"], text=r["text"], source=r["source"], + ) + for r in rows + ] + + def prune_history(self, *, older_than_days: int = HISTORY_KEEP_DAYS) -> int: + cutoff = time.time() - older_than_days * 86400 + with self._txn() as cur: + cur.execute("DELETE FROM wx_reply_history WHERE ts < ?", (cutoff,)) + return cur.rowcount + + +# ── In-memory fallback (no SQLite) ──────────────────────────────────────── + + +class InMemoryStore: + """Thread-safe in-memory store with the same shape as :class:`SqliteStore`. + + Used when SQLite init fails (read-only filesystem, broken db file, + permissions). All methods match the persistent store's signatures. + """ + + def __init__(self, *, timeout_s: float = DEFAULT_TIMEOUT_S) -> None: + self._panels: dict[str, PendingPanel] = {} # panel_id → panel + self._history: list[ReplyHistoryEntry] = [] + self._lock = threading.Lock() + self._timeout_s = timeout_s + self._stop = threading.Event() + self._janitor: Optional[threading.Thread] = None + self._counter = 0 + + def start_janitor(self) -> None: + if self._janitor is not None: + return + self._janitor = threading.Thread( + target=self._janitor_loop, name="wx-smart-reply-janitor-mem", + daemon=True, + ) + self._janitor.start() + + def stop(self) -> None: + self._stop.set() + if self._janitor is not None: + self._janitor.join(timeout=2.0) + + def _janitor_loop(self) -> None: + while not self._stop.wait(JANITOR_TICK_S): + self.sweep_expired() + + # Panel ID generation + def assign_next_id(self) -> str: + with self._lock: + cid = n_to_id(self._counter) + self._counter += 1 + return cid + + # Panel operations + def put(self, panel: PendingPanel) -> None: + with self._lock: + self._panels[panel.panel_id] = panel + + def take_active(self) -> Optional[PendingPanel]: + now = time.time() + with self._lock: + active = [p for p in self._panels.values() if p.expires_at > now] + if not active: + return None + active.sort(key=lambda p: p.created_at, reverse=True) + return active[0] + + def get_by_id(self, panel_id: str) -> Optional[PendingPanel]: + with self._lock: + p = self._panels.get(panel_id) + return p if (p and p.expires_at > time.time()) else None + + def list_active(self) -> list[PendingPanel]: + now = time.time() + with self._lock: + return sorted([p for p in self._panels.values() if p.expires_at > now], + key=lambda p: p.created_at) + + def consume(self, target_uid: str) -> Optional[PendingPanel]: + now = time.time() + with self._lock: + cands = [p for p in self._panels.values() + if p.target_uid == target_uid and p.expires_at > now] + if not cands: + return None + cands.sort(key=lambda p: p.created_at, reverse=True) + best = cands[0] + del self._panels[best.panel_id] + return best + + def consume_by_id(self, panel_id: str) -> Optional[PendingPanel]: + with self._lock: + return self._panels.pop(panel_id, None) + + def sweep_expired(self) -> int: + now = time.time() + with self._lock: + expired = [pid for pid, p in self._panels.items() + if p.expires_at <= now] + for pid in expired: + del self._panels[pid] + return len(expired) + + def __len__(self) -> int: + return len(self.list_active()) + + def write_reply(self, *, to_uid: str, to_label: Optional[str], + text: str, source: Optional[str] = None, + ts: Optional[float] = None) -> None: + with self._lock: + self._history.append(ReplyHistoryEntry( + ts=ts if ts is not None else time.time(), + to_uid=to_uid, to_label=to_label, text=text, source=source, + )) + # Cap to last 1000 to bound memory + if len(self._history) > 1000: + self._history = self._history[-1000:] + + def recent_replies(self, n: int = 20, + *, exclude_uid: Optional[str] = None) -> list[ReplyHistoryEntry]: + with self._lock: + rows = self._history + if exclude_uid: + rows = [r for r in rows if r.to_uid != exclude_uid] + return list(reversed(rows))[:n] + + def prune_history(self, *, older_than_days: int = HISTORY_KEEP_DAYS) -> int: + cutoff = time.time() - older_than_days * 86400 + with self._lock: + before = len(self._history) + self._history = [r for r in self._history if r.ts >= cutoff] + return before - len(self._history) + + +# ── Store factory ────────────────────────────────────────────────────────── + + +def make_store(*, db_path: Optional[Path] = None, + timeout_s: float = DEFAULT_TIMEOUT_S, + prefer_sqlite: bool = True): + """Build a store: SQLite by default, fall back to in-memory on failure.""" + if not prefer_sqlite: + return InMemoryStore(timeout_s=timeout_s) + try: + return SqliteStore(db_path, timeout_s=timeout_s) + except (sqlite3.Error, OSError): + return InMemoryStore(timeout_s=timeout_s) + + +# ── Helpers ─────────────────────────────────────────────────────────────── + + +def _row_to_panel(row) -> PendingPanel: + return PendingPanel( + panel_id=row["panel_id"], + target_uid=row["target_uid"], + target_label=row["target_label"], + message=row["message"], + candidates=json.loads(row["candidates"]), + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + + +# ── Contacts (JSON-backed; user-edited source of truth) ─────────────────── + + +class ContactsStore: + """Thin loader/saver for ``~/.cheetahclaws/wx_contacts.json``. + + Schema:: + + { + "wxid_alice": { + "label": "Alice (大学同学)", + "relationship": "close friend", + "notes": "她最近在找工作。语气随便,喜欢用 emoji。" + }, + ... + } + + Missing or unreadable file → empty store; lookups return None so + callers can short-circuit cleanly. + """ + + def __init__(self, path: Optional[Path] = None) -> None: + self.path = Path(path) if path is not None else DEFAULT_CONTACTS_JSON + self._lock = threading.Lock() + self._data: dict[str, dict] = {} + self._mtime: float = 0.0 + self._load() + + def _load(self) -> None: + try: + stat = self.path.stat() + mtime = stat.st_mtime + except (FileNotFoundError, OSError): + self._data = {} + self._mtime = 0.0 + return + if mtime == self._mtime: + return + try: + self._data = json.loads(self.path.read_text(encoding="utf-8")) + if not isinstance(self._data, dict): + self._data = {} + except (json.JSONDecodeError, OSError): + self._data = {} + self._mtime = mtime + + def get(self, uid: str) -> Optional[Contact]: + """Return the contact for ``uid``, reloading the file if mtime changed.""" + with self._lock: + self._load() + entry = self._data.get(uid) + if not entry: + return None + return Contact( + uid=uid, + label=entry.get("label"), + relationship=entry.get("relationship"), + notes=entry.get("notes"), + ) + + def all(self) -> dict[str, Contact]: + with self._lock: + self._load() + return { + uid: Contact( + uid=uid, + label=v.get("label"), + relationship=v.get("relationship"), + notes=v.get("notes"), + ) + for uid, v in self._data.items() + } + + def set(self, contact: Contact) -> None: + with self._lock: + self._load() + self._data[contact.uid] = { + k: v for k, v in { + "label": contact.label, + "relationship": contact.relationship, + "notes": contact.notes, + }.items() if v is not None + } + self._save() + + def delete(self, uid: str) -> bool: + with self._lock: + self._load() + existed = self._data.pop(uid, None) is not None + if existed: + self._save() + return existed + + def _save(self) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + tmp = self.path.with_suffix(self.path.suffix + ".tmp") + tmp.write_text(json.dumps(self._data, indent=2, ensure_ascii=False), + encoding="utf-8") + tmp.replace(self.path) + try: + stat = self.path.stat() + self._mtime = stat.st_mtime + except OSError: + pass diff --git a/cc_config.py b/cc_config.py index 51bcca7..e50cb1a 100644 --- a/cc_config.py +++ b/cc_config.py @@ -68,6 +68,19 @@ # "qwen_api_key": "..." # "zhipu_api_key": "..." # "deepseek_api_key": "..." + # ── WeChat smart-reply (off by default) ──────────────────────────────── + # When enabled, inbound messages from whitelisted contacts no longer + # auto-reply via the agent. Instead the auxiliary cheap model drafts + # 3 candidate replies and pushes them to the user's `filehelper` + # (文件传输助手) chat for approval. See bridges/wechat_smart_reply.py. + "wechat_smart_reply": False, + "wechat_smart_reply_whitelist": [], # list of from_user_id strings + "wechat_smart_reply_groups": False, # also draft for group messages + "wechat_smart_reply_groups_at_only": False, # in groups, only when @ + "wechat_smart_reply_timeout_s": 300, # panel expiry seconds + # WeChat self nickname — needed for groups_at_only matching. Not set + # automatically; user provides via config or `/wechat self `. + "wechat_self_nickname": "", } diff --git a/tests/test_wechat_smart_reply.py b/tests/test_wechat_smart_reply.py new file mode 100644 index 0000000..2b03029 --- /dev/null +++ b/tests/test_wechat_smart_reply.py @@ -0,0 +1,840 @@ +"""Unit tests for bridges/wechat_smart_reply.py + wechat_smart_reply_store.py. + +We don't exercise the full WeChat poll loop here — the iLink protocol +needs a real account to drive end-to-end. These tests cover the logic +that's testable in isolation: gating rules, store backends, parsing, +candidate extraction, prompt construction, and the high-level entry +points with stubs. +""" +from __future__ import annotations + +import json +import os +import sys +import time +from pathlib import Path +from typing import Optional + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from bridges import wechat_smart_reply as sr +from bridges import wechat_smart_reply_store as srs + + +# ── Identity heuristics ──────────────────────────────────────────────────── + +def test_is_filehelper(): + assert sr.is_filehelper("filehelper") is True + assert sr.is_filehelper("wxid_alice") is False + assert sr.is_filehelper("group@chatroom") is False + + +def test_is_group(): + assert sr.is_group("12345@chatroom") is True + assert sr.is_group("wxid_alice") is False + assert sr.is_group("filehelper") is False + + +# ── Gating: is_smart_reply_target ────────────────────────────────────────── + +def test_target_off_by_default(): + assert sr.is_smart_reply_target("wxid_alice", {}) is False + + +def test_target_filehelper_excluded(): + cfg = {"wechat_smart_reply": True} + assert sr.is_smart_reply_target("filehelper", cfg) is False + + +def test_target_group_excluded_by_default(): + cfg = {"wechat_smart_reply": True} + assert sr.is_smart_reply_target("12345@chatroom", cfg) is False + + +def test_target_group_included_when_groups_on(): + cfg = {"wechat_smart_reply": True, "wechat_smart_reply_groups": True} + assert sr.is_smart_reply_target("12345@chatroom", cfg, text="anything") is True + + +def test_target_whitelist_includes(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_whitelist": ["wxid_alice", "wxid_bob"], + } + assert sr.is_smart_reply_target("wxid_alice", cfg) is True + assert sr.is_smart_reply_target("wxid_carol", cfg) is False + + +def test_target_empty_whitelist_means_everyone(): + cfg = {"wechat_smart_reply": True} + assert sr.is_smart_reply_target("wxid_random", cfg) is True + + +# ── Group @-only rule ────────────────────────────────────────────────────── + +def test_group_at_only_blocks_message_without_at(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_groups": True, + "wechat_smart_reply_groups_at_only": True, + "wechat_self_nickname": "李明", + } + assert sr.is_smart_reply_target( + "g@chatroom", cfg, text="今天天气不错") is False + + +def test_group_at_only_allows_message_with_at(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_groups": True, + "wechat_smart_reply_groups_at_only": True, + "wechat_self_nickname": "李明", + } + assert sr.is_smart_reply_target( + "g@chatroom", cfg, text="@李明 帮我看下这个") is True + + +def test_group_at_only_eos_boundary(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_groups": True, + "wechat_smart_reply_groups_at_only": True, + "wechat_self_nickname": "李明", + } + # @nickname at end of string with no trailing char should still match + assert sr.is_smart_reply_target("g@chatroom", cfg, text="hi @李明") is True + + +def test_group_at_only_substring_does_not_match(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_groups": True, + "wechat_smart_reply_groups_at_only": True, + "wechat_self_nickname": "李", + } + # @李明 contains "李" but should not match nickname "李" because of + # boundary rule (the next char after "李" is "明", not space/punct). + assert sr.is_smart_reply_target("g@chatroom", cfg, text="@李明 hi") is False + + +def test_group_at_only_no_nickname_set_blocks_all(): + cfg = { + "wechat_smart_reply": True, + "wechat_smart_reply_groups": True, + "wechat_smart_reply_groups_at_only": True, + "wechat_self_nickname": "", + } + assert sr.is_smart_reply_target("g@chatroom", cfg, text="@anyone hi") is False + + +# ── Panel-ID generation ──────────────────────────────────────────────────── + +def test_n_to_id_first_few(): + assert srs.n_to_id(0) == "AA" + assert srs.n_to_id(1) == "AB" + assert srs.n_to_id(25) == "AZ" + assert srs.n_to_id(26) == "BA" + + +def test_n_to_id_wraps(): + assert srs.n_to_id(26 * 26) == "AA" # wraps after 676 + + +# ── make_panel ───────────────────────────────────────────────────────────── + +def _mk_panel(uid: str, *, panel_id: str = "AA", + expires_in: float = 60) -> sr.PendingPanel: + return sr.make_panel(uid, "Alice", "hi", ["a", "b", "c"], + panel_id=panel_id, timeout_s=expires_in) + + +# ── In-memory store ─────────────────────────────────────────────────────── + +def test_store_assign_next_id_monotonic(): + store = srs.InMemoryStore() + assert store.assign_next_id() == "AA" + assert store.assign_next_id() == "AB" + assert store.assign_next_id() == "AC" + + +def test_store_put_and_take_active(): + store = srs.InMemoryStore() + p = _mk_panel("wxid_alice") + store.put(p) + assert len(store) == 1 + got = store.take_active() + assert got is not None + assert got.target_uid == "wxid_alice" + + +def test_store_take_returns_most_recent(): + store = srs.InMemoryStore() + p1 = _mk_panel("wxid_alice", panel_id="AA") + time.sleep(0.01) + p2 = _mk_panel("wxid_bob", panel_id="AB") + store.put(p1) + store.put(p2) + got = store.take_active() + assert got.target_uid == "wxid_bob" + + +def test_store_skips_expired(): + store = srs.InMemoryStore() + store.put(_mk_panel("wxid_alice", expires_in=-1)) # already expired + assert store.take_active() is None + + +def test_store_consume_by_uid(): + store = srs.InMemoryStore() + store.put(_mk_panel("wxid_alice")) + assert store.consume("wxid_alice") is not None + assert store.consume("wxid_alice") is None # idempotent + assert len(store) == 0 + + +def test_store_get_by_id(): + store = srs.InMemoryStore() + store.put(_mk_panel("wxid_alice", panel_id="AA")) + p = store.get_by_id("AA") + assert p is not None + assert p.target_uid == "wxid_alice" + assert store.get_by_id("ZZ") is None + + +def test_store_consume_by_id(): + store = srs.InMemoryStore() + store.put(_mk_panel("wxid_alice", panel_id="AA")) + assert store.consume_by_id("AA") is not None + assert store.consume_by_id("AA") is None # already consumed + + +def test_store_list_active_returns_oldest_first(): + store = srs.InMemoryStore() + store.put(_mk_panel("u1", panel_id="AA")) + time.sleep(0.01) + store.put(_mk_panel("u2", panel_id="AB")) + time.sleep(0.01) + store.put(_mk_panel("u3", panel_id="AC")) + listed = store.list_active() + assert [p.panel_id for p in listed] == ["AA", "AB", "AC"] + + +def test_store_sweep_expired_returns_count(): + store = srs.InMemoryStore() + store.put(_mk_panel("u1", panel_id="AA", expires_in=-1)) + store.put(_mk_panel("u2", panel_id="AB", expires_in=60)) + swept = store.sweep_expired() + assert swept == 1 + assert store.list_active()[0].panel_id == "AB" + + +def test_store_history_write_and_recent(): + store = srs.InMemoryStore() + store.write_reply(to_uid="u1", to_label="A", text="hi", source="candidate_1") + store.write_reply(to_uid="u2", to_label="B", text="bye", source="freeform") + rows = store.recent_replies(n=10) + assert len(rows) == 2 + assert rows[0].text == "bye" # newest first + + +def test_store_history_excludes_uid(): + store = srs.InMemoryStore() + store.write_reply(to_uid="u1", to_label="A", text="hi") + store.write_reply(to_uid="u2", to_label="B", text="bye") + rows = store.recent_replies(n=10, exclude_uid="u1") + assert [r.to_uid for r in rows] == ["u2"] + + +def test_store_history_pruning(): + store = srs.InMemoryStore() + old_ts = time.time() - (40 * 86400) + store.write_reply(to_uid="u", to_label="A", text="ancient", ts=old_ts) + store.write_reply(to_uid="u", to_label="A", text="recent") + pruned = store.prune_history(older_than_days=30) + assert pruned == 1 + assert [r.text for r in store.recent_replies(n=5)] == ["recent"] + + +# ── SQLite store ────────────────────────────────────────────────────────── + + +def test_sqlite_store_schema_initializes(tmp_path): + store = srs.SqliteStore(tmp_path / "wx.db") + try: + # Schema should be created idempotently — second open works fine. + store2 = srs.SqliteStore(tmp_path / "wx.db") + store2.stop() + finally: + store.stop() + + +def test_sqlite_store_panel_roundtrip(tmp_path): + store = srs.SqliteStore(tmp_path / "wx.db") + try: + pid = store.assign_next_id() + p = sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=pid) + store.put(p) + got = store.get_by_id(pid) + assert got is not None + assert got.candidates == ["A", "B", "C"] + assert got.target_label == "Alice" + finally: + store.stop() + + +def test_sqlite_store_persists_across_reopen(tmp_path): + db = tmp_path / "wx.db" + store1 = srs.SqliteStore(db) + pid = store1.assign_next_id() + store1.put(sr.make_panel("u1", "Alice", "hi", ["a", "b", "c"], panel_id=pid)) + store1.stop() + + store2 = srs.SqliteStore(db) + try: + assert store2.get_by_id(pid) is not None + # ID counter persists too + assert store2.assign_next_id() == "AB" + finally: + store2.stop() + + +def test_sqlite_store_history_persists(tmp_path): + db = tmp_path / "wx.db" + s1 = srs.SqliteStore(db) + s1.write_reply(to_uid="u1", to_label="Alice", text="好的", source="candidate_1") + s1.stop() + s2 = srs.SqliteStore(db) + try: + rows = s2.recent_replies(n=5) + assert len(rows) == 1 + assert rows[0].text == "好的" + finally: + s2.stop() + + +def test_make_store_falls_back_to_memory_when_sqlite_blocked(tmp_path, + monkeypatch): + import sqlite3 as _sqlite3 + real_connect = _sqlite3.connect + + def boom(*args, **kwargs): + raise _sqlite3.OperationalError("simulated") + + monkeypatch.setattr(_sqlite3, "connect", boom) + store = srs.make_store(db_path=tmp_path / "wx.db") + monkeypatch.setattr(_sqlite3, "connect", real_connect) + assert isinstance(store, srs.InMemoryStore) + + +# ── ParsedAction parsing ───────────────────────────────────────────────── + + +def test_parse_no_panel_returns_noop(): + store = srs.InMemoryStore() + assert sr.parse_filehelper_input("1", store).kind == "noop" + + +def test_parse_queue_command(): + store = srs.InMemoryStore() + assert sr.parse_filehelper_input("q", store).kind == "list" + assert sr.parse_filehelper_input("queue", store).kind == "list" + assert sr.parse_filehelper_input("队列", store).kind == "list" + + +def test_parse_numeric_choice_uses_latest_active(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", + ["A", "B", "C"], panel_id=pid)) + a = sr.parse_filehelper_input("2", store) + assert a.kind == "send" + assert a.text == "B" + assert a.panel_id == pid + + +def test_parse_skip_uses_latest_active(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", + ["A", "B", "C"], panel_id=pid)) + for tok in ("x", "X", "skip", "/skip", "跳过"): + a = sr.parse_filehelper_input(tok, store) + assert a.kind == "skip" + assert a.panel_id == pid + + +def test_parse_freeform_uses_latest_active(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", + ["A", "B", "C"], panel_id=pid)) + a = sr.parse_filehelper_input("我自己写的", store) + assert a.kind == "send" + assert a.text == "我自己写的" + + +def test_parse_explicit_panel_id_addressing(): + store = srs.InMemoryStore() + p1 = store.assign_next_id() # AA + p2 = store.assign_next_id() # AB + store.put(sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=p1)) + store.put(sr.make_panel("u2", "Bob", "hi", ["X", "Y", "Z"], panel_id=p2)) + # latest is u2 (AB); but explicit AA addresses Alice's panel + a = sr.parse_filehelper_input("AA 1", store) + assert a.kind == "send" + assert a.panel_id == "AA" + assert a.text == "A" + + +def test_parse_panel_id_alone_lists_panel(): + store = srs.InMemoryStore() + p1 = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=p1)) + a = sr.parse_filehelper_input("AA", store) + assert a.kind == "list" + assert a.panel_id == "AA" + + +def test_parse_unknown_panel_id_returns_noop(): + store = srs.InMemoryStore() + p1 = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=p1)) + a = sr.parse_filehelper_input("ZZ 1", store) + # ZZ doesn't exist → _classify_choice with panel=None → noop, panel_id="ZZ" + assert a.kind == "noop" + assert a.panel_id == "ZZ" + + +# ── Candidate extraction ────────────────────────────────────────────────── + +def test_parse_candidates_clean_format(): + text = "1. 好的\n2. 周末出差\n3. 在忙稍后" + out = sr._parse_candidates(text) + assert out == ["好的", "周末出差", "在忙稍后"] + + +def test_parse_candidates_strips_trailing_punct(): + text = "1. 好的。\n2. 不行!\n3. 在忙呢…" + out = sr._parse_candidates(text) + assert out[0] == "好的" + assert out[1] == "不行" + + +def test_parse_candidates_handles_chinese_dot(): + text = "1、好\n2、行\n3、不" + out = sr._parse_candidates(text) + assert out == ["好", "行", "不"] + + +def test_parse_candidates_caps_at_three(): + text = "1. a\n2. b\n3. c\n4. d\n5. e" + out = sr._parse_candidates(text) + assert out == ["a", "b", "c"] + + +def test_parse_candidates_empty_on_garbage(): + assert sr._parse_candidates("just some prose, no numbered list") == [] + + +# ── Prompt construction (style + contact context) ───────────────────────── + +def test_build_prompt_without_extras(): + prompt = sr._build_prompt("hi", "Alice", contact=None, history=[]) + assert "hi" in prompt + assert "Alice" in prompt + assert "关于发件人" not in prompt + assert "用户最近发出" not in prompt + + +def test_build_prompt_includes_contact_relationship(): + contact = sr.Contact(uid="u", label="Alice", relationship="close friend", + notes="她在找工作") + prompt = sr._build_prompt("hi", "Alice", contact=contact, history=[]) + assert "close friend" in prompt + assert "她在找工作" in prompt + + +def test_build_prompt_includes_history_examples(): + history = [ + sr.ReplyHistoryEntry(ts=time.time(), to_uid="u1", to_label="A", + text="哈哈好的", source="candidate_1"), + sr.ReplyHistoryEntry(ts=time.time(), to_uid="u2", to_label="B", + text="晚点回你", source="freeform"), + ] + prompt = sr._build_prompt("hi", "Alice", contact=None, history=history) + assert "哈哈好的" in prompt + assert "晚点回你" in prompt + assert "用户最近发出" in prompt + + +def test_build_prompt_history_capped_at_ten(): + history = [ + sr.ReplyHistoryEntry(ts=time.time(), to_uid="u", to_label=None, + text=f"reply {i}", source=None) + for i in range(50) + ] + prompt = sr._build_prompt("hi", "Alice", contact=None, history=history) + assert "reply 0" in prompt + # 11+ should not appear (we cap at 10) + assert "reply 15" not in prompt + + +# ── generate_candidates with stub ────────────────────────────────────────── + +class _TextChunk: + def __init__(self, text: str): + self.text = text + + +def _stub_stream_3(**kwargs): + yield _TextChunk("1. 行\n2. 在忙\n3. 晚点回") + + +def _stub_stream_2(**kwargs): + yield _TextChunk("1. 好\n2. 在忙") + + +def _stub_stream_garbage(**kwargs): + yield _TextChunk("This is not a list.") + + +def _stub_stream_raises(**kwargs): + raise RuntimeError("boom") + yield # pragma: no cover + + +def test_generate_candidates_happy_path(): + out = sr.generate_candidates( + "周末有空吗", "张三", {"auxiliary_model": "test"}, + _stream_fn=_stub_stream_3, + ) + assert out == ["行", "在忙", "晚点回"] + + +def test_generate_candidates_partial_returned_as_is(): + out = sr.generate_candidates( + "x", "Y", {"auxiliary_model": "test"}, _stream_fn=_stub_stream_2, + ) + assert out == ["好", "在忙"] + + +def test_generate_candidates_returns_empty_on_garbage(): + out = sr.generate_candidates( + "x", "Y", {"auxiliary_model": "test"}, _stream_fn=_stub_stream_garbage, + ) + assert out == [] + + +def test_generate_candidates_returns_empty_on_exception(): + out = sr.generate_candidates( + "x", "Y", {"auxiliary_model": "test"}, _stream_fn=_stub_stream_raises, + ) + assert out == [] + + +def test_generate_candidates_threads_contact_into_prompt(): + captured = {} + + def capture_stream(**kwargs): + captured["msg"] = kwargs["messages"][0]["content"] + yield _TextChunk("1. ok\n2. fine\n3. sure") + + contact = sr.Contact(uid="u", label="Alice", + relationship="ex-coworker", notes="formal tone") + sr.generate_candidates( + "hi", "Alice", {"auxiliary_model": "test"}, + contact=contact, _stream_fn=capture_stream, + ) + assert "ex-coworker" in captured["msg"] + assert "formal tone" in captured["msg"] + + +# ── format_panel + format_queue ─────────────────────────────────────────── + +def test_format_panel_includes_id_label_message(): + p = sr.make_panel("wxid_alice", "Alice", "你好", + ["a", "b", "c"], panel_id="AA") + out = sr.format_panel(p) + assert "[AA]" in out + assert "Alice" in out + assert "你好" in out + assert "[1] a" in out + assert "q 看队列" in out + + +def test_format_queue_empty(): + out = sr.format_queue([]) + assert "没有" in out + + +def test_format_queue_lists_with_ids(): + p1 = sr.make_panel("u1", "Alice", "msg1", ["a", "b", "c"], panel_id="AA") + p2 = sr.make_panel("u2", "Bob", "msg2", ["x", "y", "z"], panel_id="AB") + out = sr.format_queue([p1, p2]) + assert "[AA]" in out + assert "[AB]" in out + assert "Alice" in out + assert "Bob" in out + + +def test_format_panel_truncates_long_message(): + long_msg = "x" * 300 + p = sr.make_panel("u", "U", long_msg, ["a", "b", "c"], panel_id="AA") + out = sr.format_panel(p) + assert "…" in out + + +# ── trigger_smart_reply (end-to-end with stubs) ────────────────────────── + +def test_trigger_happy_path_with_in_memory_store(): + sent: list = [] + store = srs.InMemoryStore() + cfg = {"auxiliary_model": "test"} + ok = sr.trigger_smart_reply( + target_uid="wxid_alice", target_label="Alice", + message="周末有空", store=store, config=cfg, + send_to_filehelper=lambda txt: sent.append(txt), + generate_fn=lambda *a, **k: ["行", "忙", "晚点回"], + ) + assert ok is True + assert len(sent) == 1 + assert "[AA]" in sent[0] # first panel gets first ID + assert "Alice" in sent[0] + assert len(store) == 1 + + +def test_trigger_pads_short_list_to_three(): + store = srs.InMemoryStore() + sent: list = [] + sr.trigger_smart_reply( + target_uid="u", target_label="U", + message="x", store=store, config={}, + send_to_filehelper=lambda txt: sent.append(txt), + generate_fn=lambda *a, **k: ["哈哈"], + ) + panel = store.take_active() + assert len(panel.candidates) == 3 + + +def test_trigger_returns_false_on_empty_generation(): + store = srs.InMemoryStore() + sent: list = [] + ok = sr.trigger_smart_reply( + target_uid="u", target_label="U", message="x", + store=store, config={}, + send_to_filehelper=lambda txt: sent.append(txt), + generate_fn=lambda *a, **k: [], + ) + assert ok is False + assert sent == [] + assert len(store) == 0 + + +def test_trigger_passes_contact_into_generator(tmp_path): + captured = {} + + def fake_gen(message, label, config, *, contact=None, history=None): + captured["contact"] = contact + captured["history"] = history + return ["a", "b", "c"] + + contacts_path = tmp_path / "wx_contacts.json" + contacts_path.write_text(json.dumps({ + "u": {"label": "Alice", "relationship": "friend"} + })) + contacts = sr.ContactsStore(path=contacts_path) + + store = srs.InMemoryStore() + sr.trigger_smart_reply( + target_uid="u", target_label="U", message="x", + store=store, config={}, + send_to_filehelper=lambda txt: None, + contacts=contacts, + generate_fn=fake_gen, + ) + assert captured["contact"] is not None + assert captured["contact"].relationship == "friend" + + +# ── handle_filehelper_message ───────────────────────────────────────────── + + +def test_handle_no_active_panel_falls_through(): + store = srs.InMemoryStore() + consumed = sr.handle_filehelper_message( + "anything", store, + send_to_target=lambda u, t: None, + send_to_filehelper=lambda t: None, + ) + assert consumed is False + + +def test_handle_numeric_sends_candidate_and_records_history(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("wxid_alice", "Alice", "你好", + ["A", "B", "C"], panel_id=pid)) + sent_target: list = [] + sent_fh: list = [] + consumed = sr.handle_filehelper_message( + "2", store, + send_to_target=lambda u, t: sent_target.append((u, t)), + send_to_filehelper=lambda t: sent_fh.append(t), + ) + assert consumed is True + assert sent_target == [("wxid_alice", "B")] + history = store.recent_replies(n=5) + assert len(history) == 1 + assert history[0].text == "B" + assert history[0].source == "candidate_2" + + +def test_handle_skip_drops_panel_no_history(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("wxid_alice", "Alice", "你好", + ["A", "B", "C"], panel_id=pid)) + sent_target: list = [] + consumed = sr.handle_filehelper_message( + "x", store, + send_to_target=lambda u, t: sent_target.append((u, t)), + send_to_filehelper=lambda t: None, + ) + assert consumed is True + assert sent_target == [] + assert store.recent_replies(n=5) == [] # skip ≠ send → no history row + + +def test_handle_freeform_records_as_freeform(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("u", "U", "hi", ["A", "B", "C"], panel_id=pid)) + sr.handle_filehelper_message( + "我自己写的", store, + send_to_target=lambda u, t: None, + send_to_filehelper=lambda t: None, + ) + history = store.recent_replies(n=5) + assert history[0].source == "freeform" + assert history[0].text == "我自己写的" + + +def test_handle_queue_command_lists_pending(): + store = srs.InMemoryStore() + p1 = store.assign_next_id() + p2 = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", ["a", "b", "c"], panel_id=p1)) + store.put(sr.make_panel("u2", "Bob", "hi", ["x", "y", "z"], panel_id=p2)) + sent_fh: list = [] + consumed = sr.handle_filehelper_message( + "q", store, + send_to_target=lambda u, t: None, + send_to_filehelper=lambda t: sent_fh.append(t), + ) + assert consumed is True + assert any("Alice" in s and "Bob" in s for s in sent_fh) + + +def test_handle_explicit_panel_id_addressing(): + store = srs.InMemoryStore() + p1 = store.assign_next_id() # AA + p2 = store.assign_next_id() # AB + store.put(sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=p1)) + store.put(sr.make_panel("u2", "Bob", "hi", ["X", "Y", "Z"], panel_id=p2)) + sent_target: list = [] + sr.handle_filehelper_message( + "AA 1", store, + send_to_target=lambda u, t: sent_target.append((u, t)), + send_to_filehelper=lambda t: None, + ) + # AA = Alice's panel; "1" = first candidate "A" + assert sent_target == [("u1", "A")] + # Bob's panel still pending + assert store.get_by_id("AB") is not None + + +def test_handle_unknown_panel_id_returns_warning(): + store = srs.InMemoryStore() + pid = store.assign_next_id() + store.put(sr.make_panel("u1", "Alice", "hi", ["A", "B", "C"], panel_id=pid)) + sent_fh: list = [] + consumed = sr.handle_filehelper_message( + "ZZ 1", store, + send_to_target=lambda u, t: None, + send_to_filehelper=lambda t: sent_fh.append(t), + ) + # ZZ doesn't exist; parse_filehelper_input returns kind=noop with panel_id="ZZ" + # but handle_filehelper_message should fall through (no known panel). + assert consumed is False + + +# ── ContactsStore ───────────────────────────────────────────────────────── + + +def test_contacts_missing_file_returns_none(tmp_path): + store = sr.ContactsStore(path=tmp_path / "nonexistent.json") + assert store.get("anyone") is None + assert store.all() == {} + + +def test_contacts_set_and_get_roundtrip(tmp_path): + p = tmp_path / "wx_contacts.json" + store = sr.ContactsStore(path=p) + store.set(sr.Contact(uid="wxid_alice", label="Alice (friend)", + relationship="close friend", notes="loves coffee")) + # File written + assert p.exists() + data = json.loads(p.read_text(encoding="utf-8")) + assert data["wxid_alice"]["relationship"] == "close friend" + # Round-trip via fresh store + fresh = sr.ContactsStore(path=p) + got = fresh.get("wxid_alice") + assert got is not None + assert got.notes == "loves coffee" + + +def test_contacts_delete(tmp_path): + p = tmp_path / "wx_contacts.json" + store = sr.ContactsStore(path=p) + store.set(sr.Contact(uid="u", label="L")) + assert store.delete("u") is True + assert store.delete("u") is False # idempotent + assert store.get("u") is None + + +def test_contacts_corrupt_file_returns_empty(tmp_path): + p = tmp_path / "wx_contacts.json" + p.write_text("not valid json") + store = sr.ContactsStore(path=p) + assert store.all() == {} + assert store.get("any") is None + + +def test_contacts_reload_on_mtime_change(tmp_path): + p = tmp_path / "wx_contacts.json" + p.write_text(json.dumps({"u": {"label": "First"}})) + store = sr.ContactsStore(path=p) + assert store.get("u").label == "First" + + # Edit file out-of-band then reload + time.sleep(0.05) + p.write_text(json.dumps({"u": {"label": "Second"}})) + # Force mtime change (some filesystems have low resolution) + new_mtime = p.stat().st_mtime + 1 + os.utime(p, (new_mtime, new_mtime)) + assert store.get("u").label == "Second" + + +# ── Config defaults ─────────────────────────────────────────────────────── + + +def test_cc_config_defaults_present(): + from cc_config import DEFAULTS + assert DEFAULTS["wechat_smart_reply"] is False + assert DEFAULTS["wechat_smart_reply_whitelist"] == [] + assert DEFAULTS["wechat_smart_reply_groups"] is False + assert DEFAULTS["wechat_smart_reply_groups_at_only"] is False + assert DEFAULTS["wechat_smart_reply_timeout_s"] == 300 + assert DEFAULTS["wechat_self_nickname"] == ""