Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 116 additions & 2 deletions crates/core/src/convo_miner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::normalize::normalize;
use crate::palace_db::PalaceDb;
use chrono::Utc;
use sha2::{Digest, Sha256};
use std::io::Write;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;

Expand Down Expand Up @@ -426,7 +427,20 @@ fn generate_drawer_id(wing: &str, room: &str, source_file: &str, chunk_index: us
format!("drawer_{}_{}_{}", wing, room, &hex[..24])
}

/// Collect the conversation files under `convo_dir` that should be mined.
///
/// Thin wrapper over [`scan_convos_with_log`] that routes the skip
/// diagnostic to the process's `stderr`. Symlinks are never returned (they
/// could otherwise lead into recursive structures or `/dev/urandom`), and
/// oversized files are skipped too. Every skipped symlink is reported as a
/// `" SKIP: <relative-path> (symlink)"` line on `stderr`, so a caller can
/// tell why a directory looks empty after the walk (#1462).
fn scan_convos(convo_dir: &Path) -> Vec<PathBuf> {
    let mut stderr = std::io::stderr();
    scan_convos_with_log(convo_dir, &mut stderr)
}

/// Same as [`scan_convos`] but routes the skipped-symlink diagnostic to an
/// arbitrary writer. Lets unit tests assert the log fires without having to
/// fork a subprocess to capture stderr.
fn scan_convos_with_log<W: Write>(convo_dir: &Path, skip_log: &mut W) -> Vec<PathBuf> {
let mut files = Vec::new();
for entry in WalkDir::new(convo_dir)
.follow_links(false)
Expand All @@ -437,15 +451,20 @@ fn scan_convos(convo_dir: &Path) -> Vec<PathBuf> {
})
.filter_map(|entry| entry.ok())
{
if !entry.file_type().is_file() {
let ft = entry.file_type();
// Let regular files AND symlinks through. Walkdir's `is_file()`
// returns `false` for symlinks-to-files under `follow_links(false)`,
// so a bare `!is_file()` check would silently drop every symlink
// before the diagnostic branch below can fire.
if !ft.is_file() && !ft.is_symlink() {
continue;
}
let path = entry.path();
let name = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or_default();
if name.ends_with(".meta.json") || path.is_symlink() {
if name.ends_with(".meta.json") {
continue;
}
let extension = path
Expand All @@ -456,6 +475,20 @@ fn scan_convos(convo_dir: &Path) -> Vec<PathBuf> {
if !CONVO_EXTENSIONS.contains(&extension.as_str()) {
continue;
}
// Skip symlinks — prevents following recursive/bogus links. Log to
// `skip_log` with the path relative to the scan root so the
// diagnostic is unambiguous and renders with forward slashes on
// every platform. Runs AFTER the extension filter to match upstream
// Python `scan_convos` ordering — a `.png` symlink is silently
// dropped at the extension gate rather than logged. (#1462)
if ft.is_symlink() {
let rel = path
.strip_prefix(convo_dir)
.map(|p| p.to_string_lossy().replace('\\', "/"))
.unwrap_or_else(|_| path.to_string_lossy().to_string());
let _ = writeln!(skip_log, " SKIP: {rel} (symlink)");
continue;
}
let Ok(metadata) = path.metadata() else {
continue;
};
Expand Down Expand Up @@ -671,6 +704,87 @@ mod tests {
assert!(!names.contains(&"config.txt".to_string()));
}

#[cfg(unix)]
#[test]
fn test_scan_convos_skips_symlinks() {
    // Regression guard for upstream #1462. A symlinked conversation file
    // must be left out of the result AND must produce a SKIP line whose
    // path is relative to the scan root. Checking the log (not only the
    // result set) matters: result-set exclusion alone also passes when the
    // symlink branch is dead code, which is how the initial port shipped.
    let scan_root = tempfile::TempDir::new().unwrap();
    let target = scan_root.path().join("real.md");
    std::fs::write(&target, "hello world").unwrap();
    let link = scan_root.path().join("link.md");
    std::os::unix::fs::symlink(&target, link).unwrap();

    let mut captured: Vec<u8> = Vec::new();
    let found = scan_convos_with_log(scan_root.path(), &mut captured);

    let mut names: Vec<String> = Vec::with_capacity(found.len());
    for file in &found {
        names.push(file.file_name().unwrap().to_string_lossy().to_string());
    }
    assert_eq!(names, vec!["real.md".to_string()]);

    let log = String::from_utf8(captured).unwrap();
    assert!(
        log.contains(" SKIP: link.md (symlink)\n"),
        "expected SKIP diagnostic for link.md, got: {log:?}"
    );
}

#[cfg(unix)]
#[test]
fn test_scan_convos_skips_dangling_symlinks() {
    // A symlink whose target does not exist must neither panic the walker
    // nor appear in the result; it is still reported in the SKIP log.
    // Mirrors upstream coverage for the dangling-link path of #1462.
    let scan_root = tempfile::TempDir::new().unwrap();
    let root = scan_root.path();
    std::fs::write(root.join("real.md"), "hello world").unwrap();
    let missing_target = root.join("missing.md");
    let dangling_link = root.join("dangling.md");
    std::os::unix::fs::symlink(missing_target, dangling_link).unwrap();

    let mut captured: Vec<u8> = Vec::new();
    let found = scan_convos_with_log(root, &mut captured);

    let mut names: Vec<String> = Vec::with_capacity(found.len());
    for file in &found {
        names.push(file.file_name().unwrap().to_string_lossy().to_string());
    }
    assert_eq!(names, vec!["real.md".to_string()]);

    let log = String::from_utf8(captured).unwrap();
    assert!(
        log.contains(" SKIP: dangling.md (symlink)\n"),
        "expected SKIP diagnostic for dangling.md, got: {log:?}"
    );
}

#[cfg(unix)]
#[test]
fn test_scan_convos_does_not_log_extension_filtered_symlinks() {
    // The extension gate runs before the symlink diagnostic (upstream
    // Python ordering: extension → symlink-log → size). A symlink whose
    // name is not in `CONVO_EXTENSIONS` is therefore dropped silently and
    // must never show up in the SKIP log.
    let scan_root = tempfile::TempDir::new().unwrap();
    let root = scan_root.path();
    std::fs::write(root.join("real.md"), "hello world").unwrap();
    let png_link = root.join("link.png");
    std::os::unix::fs::symlink("real.md", png_link).unwrap();

    let mut captured: Vec<u8> = Vec::new();
    let found = scan_convos_with_log(root, &mut captured);

    let mut names: Vec<String> = Vec::with_capacity(found.len());
    for file in &found {
        names.push(file.file_name().unwrap().to_string_lossy().to_string());
    }
    assert_eq!(names, vec!["real.md".to_string()]);

    let log = String::from_utf8(captured).unwrap();
    assert!(
        !log.contains("link.png"),
        "extension-filtered symlink leaked into SKIP log: {log:?}"
    );
}

#[test]
fn test_scan_convos_skips_python_parity_dirs() {
let temp = tempfile::TempDir::new().unwrap();
Expand Down
173 changes: 170 additions & 3 deletions crates/core/src/mcp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
//! Exposes MemPalace functionality as MCP tools via stdio transport.
//! Read-only mode restricts mutations (diary_write, config_write, people_write).

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use rmcp::model::{
CallToolResult, Content, Implementation, InitializeResult, JsonObject, ListToolsResult,
Expand Down Expand Up @@ -209,7 +209,10 @@ where
16,
);
log_tool_invocation(&tool_name, &args, None, &trace_id);
let result = dispatch(tool_name.clone(), args.clone()).await;
let result = match validate_known_params(&tool_name, &args) {
Err(err) => Err(err),
Ok(()) => dispatch(tool_name.clone(), args.clone()).await,
};
log_tool_invocation(
&tool_name,
&args,
Expand All @@ -219,6 +222,78 @@ where
result
}

/// Tools whose Input struct accepts arbitrary extra fields (via
/// `#[serde(flatten)] custom_metadata: Option<serde_json::Value>` or
/// equivalent). These bypass the unknown-parameter check so callers can keep
/// supplying custom metadata keys — mirrors Python's `accepts_var_keyword`
/// gate on `**kwargs` handlers. (#1512)
///
/// `validate_known_params` consults this list first: a tool named here gets
/// `Ok(())` before any schema lookup, so every argument key is accepted for
/// it. Matching is by exact tool name.
const TOOLS_ACCEPTING_EXTRAS: &[&str] = &["mempalace_add_drawer"];

/// Keys that are internal transport metadata and live in no tool schema. They
/// are stripped before dispatch elsewhere; flagging them as unknown here would
/// surface a misleading error for legitimate transport-level options. (#1512)
///
/// `validate_known_params` drops these names from its unknown-key list, so
/// they are tolerated on every tool that goes through the schema check.
/// NOTE(review): the stripping "elsewhere" is not visible from this section —
/// confirm the key is removed before a tool's Input struct is deserialized.
const TRANSPORT_RESERVED_KEYS: &[&str] = &["wait_for_previous"];

/// Lazily-built lookup of `tool_name -> declared input-schema property names`.
///
/// Source of truth is `make_tools()`'s JSON schema, so adding a property in
/// one place automatically updates the unknown-parameter check. The table is
/// computed on first use and cached in a `OnceLock` for the rest of the
/// process.
///
/// A tool whose schema has no `"properties"` entry, or whose `"properties"`
/// is not a JSON object, maps to an empty set. If two tools share a name, the
/// later one from `make_tools()` wins.
fn tool_schema_props() -> &'static HashMap<String, HashSet<String>> {
    static SCHEMA_PROPS: OnceLock<HashMap<String, HashSet<String>>> = OnceLock::new();
    SCHEMA_PROPS.get_or_init(|| {
        make_tools()
            .into_iter()
            .map(|tool| {
                // Read `"properties"` straight from the shared schema map;
                // wrapping a deep clone of the whole schema in a `Value` just
                // to look up one key is unnecessary work.
                let props: HashSet<String> = tool
                    .input_schema
                    .get("properties")
                    .and_then(serde_json::Value::as_object)
                    .map(|obj| obj.keys().cloned().collect())
                    .unwrap_or_default();
                (tool.name.to_string(), props)
            })
            .collect()
    })
}

/// Fail fast on argument *names* the tool's schema does not declare.
///
/// Without this check serde silently drops an unrecognised key, and a typo
/// such as `text=` for `content=` only shows up later as a confusing
/// "missing required" error. Here it becomes a JSON-RPC -32602
/// (`invalid_params`) that lists the offending names, sorted, each in single
/// quotes. (#1512)
///
/// The check is skipped entirely for:
/// - tools in `TOOLS_ACCEPTING_EXTRAS` (their Input struct takes
///   `#[serde(flatten)]` extras, matching upstream Python's
///   `accepts_var_keyword` gate), and
/// - tool names missing from the schema table, so that dispatch can report
///   the unknown tool itself.
///
/// Keys in `TRANSPORT_RESERVED_KEYS` (the `wait_for_previous` transport
/// kwarg) are never reported.
fn validate_known_params(tool_name: &str, args: &JsonObject) -> Result<(), ErrorData> {
    let accepts_extras = TOOLS_ACCEPTING_EXTRAS
        .iter()
        .any(|name| *name == tool_name);
    if accepts_extras {
        return Ok(());
    }

    let declared = match tool_schema_props().get(tool_name) {
        Some(props) => props,
        // Unknown tool — let dispatch surface method_not_found.
        None => return Ok(()),
    };

    let mut offenders: Vec<&str> = Vec::new();
    for key in args.keys() {
        let key = key.as_str();
        if declared.contains(key) || TRANSPORT_RESERVED_KEYS.contains(&key) {
            continue;
        }
        offenders.push(key);
    }
    if offenders.is_empty() {
        return Ok(());
    }
    offenders.sort_unstable();

    let noun = match offenders.len() {
        1 => "parameter",
        _ => "parameters",
    };
    let mut listing = String::new();
    for (position, key) in offenders.iter().enumerate() {
        if position > 0 {
            listing.push_str(", ");
        }
        listing.push('\'');
        listing.push_str(key);
        listing.push('\'');
    }

    Err(ErrorData::invalid_params(
        format!("Unknown {noun} {listing} for tool {tool_name}"),
        None,
    ))
}

fn make_dispatch(state: Arc<AppState>) -> impl Fn(String, JsonObject) -> DynResult {
move |name, args| {
let state = state.clone();
Expand Down Expand Up @@ -2350,4 +2425,96 @@ mod tests {
);
assert!(result.is_ok(), "traverse failed: {:?}", result);
}

// ---------------------------------------------------------------------
// Unknown parameter name (#1512)
//
// A kwarg not in the tool schema (wrong parameter *name*, e.g. `text=`
// instead of `content=`) should surface as JSON-RPC -32602 naming the
// offending kwarg, instead of being silently dropped by serde and
// resurfacing indirectly as a later "Missing required 'X'". Symmetric
// with the missing-required-shape path. The internal `wait_for_previous`
// transport kwarg must never be flagged, and handlers whose Input struct
// uses `#[serde(flatten)]` extras (`mempalace_add_drawer`) must keep
// accepting unknown kwargs.
// ---------------------------------------------------------------------

#[test]
fn test_unknown_param_returns_invalid_params_for_wrong_kwarg_name() {
    // One misspelled kwarg (`txt`) next to a valid one must be rejected
    // with -32602 and a message naming the kwarg and the tool.
    let state = test_state();
    let args = json!({ "query": "hello", "txt": "oops" });
    let err = dispatch(&state, "mempalace_search", args)
        .expect_err("unknown 'txt' should surface as an error");
    assert_eq!(err.code.0, -32602);

    let message: &str = err.message.as_ref();
    for needle in ["'txt'", "Unknown parameter", "mempalace_search"] {
        assert!(message.contains(needle), "message: {message}");
    }
    // The error must name the wrong kwarg itself rather than the indirect
    // missing-required symptom.
    assert!(!message.contains("Missing required"), "message: {message}");
}

#[test]
fn test_two_unknown_params_list_both_names() {
    // With several unknown kwargs the message switches to the plural noun
    // and mentions every offending name.
    let state = test_state();
    let args = json!({ "query": "hello", "txt": "a", "bogus": "b" });
    let err = dispatch(&state, "mempalace_search", args)
        .expect_err("multiple unknown params should error");
    assert_eq!(err.code.0, -32602);

    let message: &str = err.message.as_ref();
    for needle in ["parameters", "'txt'", "'bogus'"] {
        assert!(message.contains(needle), "message: {message}");
    }
}

#[test]
fn test_wait_for_previous_not_flagged_as_unknown() {
    // `wait_for_previous` is transport-level metadata that appears in no
    // tool schema, so the unknown-parameter check has to let it through.
    let state = test_state();
    let args = json!({
        "agent_name": "x",
        "entry": "y",
        "wait_for_previous": true,
    });
    let outcome = dispatch(&state, "mempalace_diary_write", args);
    assert!(
        outcome.is_ok(),
        "wait_for_previous should pass through: {:?}",
        outcome
    );
}

#[test]
fn test_add_drawer_accepts_unknown_custom_metadata_keys() {
    // `mempalace_add_drawer` collects extra keys through
    // `#[serde(flatten)] custom_metadata`, so arbitrary string-valued
    // metadata (`priority`, `status` here) must not be rejected as unknown
    // parameters.
    let state = test_state();
    let args = json!({
        "wing": "test",
        "room": "decisions",
        "content": "we picked sqlite",
        "priority": "high",
        "status": "open",
    });
    let outcome = dispatch(&state, "mempalace_add_drawer", args);
    assert!(
        outcome.is_ok(),
        "add_drawer custom metadata should pass through: {:?}",
        outcome
    );
}
}
Loading
Loading