diff --git a/mcp_server/README.md b/mcp_server/README.md
index 67077b92..85198438 100644
--- a/mcp_server/README.md
+++ b/mcp_server/README.md
@@ -8,6 +8,8 @@ MCP server that exposes PaperBanana's diagram and plot generation as tools for C
 |------|-------------|
 | `generate_diagram` | Generate a methodology diagram from text context + caption |
 | `generate_plot` | Generate a statistical plot from JSON data + intent description |
+| `continue_diagram` | Continue a prior methodology diagram run (`run_*`) with more refinement and/or critic feedback; returns JSON paths |
+| `continue_plot` | Continue a prior statistical-plot run (`run_*`); same JSON contract as `continue_diagram` |
 | `evaluate_diagram` | Compare a generated diagram against a human reference (4 dimensions) |
 | `evaluate_plot` | Compare a generated statistical plot against a human reference (4 dimensions) |
 | `download_references` | Download the expanded reference set for stronger retrieval |
@@ -23,6 +25,10 @@ Long runs execute in a **worker thread** so they do not block the MCP server eve
 On validation errors (missing manifest, bad flags), the JSON body includes `"error"` and `"strict_success": false`.
 
+### Continue tools
+
+`continue_diagram` and `continue_plot` mirror `paperbanana generate --continue-run` and the Studio continue action: they load `run_input.json` and the latest iteration under `output_dir` / `run_id`, then run more visualizer–critic rounds. Pick the tool that matches the run's `diagram_type` (`methodology` vs `statistical_plot`); if it does not match, the response is `strict_success: false` with a hint to use the other tool. Successful responses include `final_image_path`, `run_dir`, and `metadata_path` when present.
+
 ## Installation
 
 ### Quick Install (via `uvx`)
diff --git a/mcp_server/server.py b/mcp_server/server.py
index d50dad08..72303d0a 100644
--- a/mcp_server/server.py
+++ b/mcp_server/server.py
@@ -6,6 +8,8 @@ Tools:
     generate_diagram — Generate a methodology diagram from text
     generate_plot — Generate a statistical plot from JSON data
+    continue_diagram — Continue a prior methodology run (more refinement / feedback)
+    continue_plot — Continue a prior statistical-plot run
     evaluate_diagram — Evaluate a generated diagram against a reference
     evaluate_plot — Evaluate a generated plot against a reference
     download_references — Download expanded reference set (~294 examples)
@@ -24,6 +26,7 @@ import os
 from io import BytesIO
 from pathlib import Path
+from typing import Any
 
 import structlog
 from fastmcp import FastMCP
@@ -32,6 +35,7 @@ from paperbanana.core.config import Settings
 from paperbanana.core.pipeline import PaperBananaPipeline
+from paperbanana.core.resume import load_resume_state
 from paperbanana.core.types import DiagramType, GenerationInput
 from paperbanana.core.utils import detect_image_mime_type, find_prompt_dir
 from paperbanana.core.workflow_runner import (
@@ -281,6 +285,116 @@ def _on_progress(event: str, payload: dict) -> None:
     return Image(path=effective_path, format=fmt)
 
 
+@mcp.tool
+async def continue_diagram(
+    run_id: str,
+    output_dir: str = "outputs",
+    feedback: str | None = None,
+    iterations: int | None = None,
+    auto_refine: bool = False,
+    max_iterations: int | None = None,
+    config: str | None = None,
+    vlm_provider: str | None = None,
+    vlm_model: str | None = None,
+    image_provider: str | None = None,
+    image_model: str | None = None,
+    output_format: str = "png",
+    optimize: bool = False,
+    save_prompts: bool | None = None,
+    venue: str | None = None,
+    generate_caption: bool = False,
+) -> str:
+    """Continue a methodology diagram run under ``output_dir`` / ``run_id``.
+
+    Loads ``run_input.json`` and the latest iteration from an existing ``run_*``
+    directory (same as ``paperbanana generate --continue-run``). Runs more
+    visualizer–critic rounds without redoing retrieval / planner / stylist.
+
+    Args:
+        run_id: Directory name (e.g. ``run_20250109_120000_abc``).
+        output_dir: Base output directory containing the run folder.
+        feedback: Optional notes for the critic (same as CLI ``--feedback``).
+        iterations: Extra refinement rounds when ``auto_refine`` is false;
+            also sets ``refinement_iterations`` in settings when provided.
+            When omitted, uses the configured default iteration count.
+        auto_refine: If true, loop until the critic is satisfied (capped by
+            ``max_iterations`` or the settings default).
+        max_iterations: Cap when ``auto_refine`` is true.
+        config: Optional path to a YAML config (same as the other MCP tools).
+        vlm_provider, vlm_model, image_provider, image_model: Optional overrides.
+        output_format: png, jpeg, or webp.
+        optimize: Passed through to settings (normally unused for continue).
+        save_prompts, venue, generate_caption: Optional settings overrides.
+
+    Returns:
+        JSON string with ``strict_success``, paths, and ``new_iteration_count``;
+        on failure, ``strict_success`` is false and ``error`` explains why.
+    """
+    return await _continue_run_mcp(
+        expected=DiagramType.METHODOLOGY,
+        run_id=run_id,
+        output_dir=output_dir,
+        feedback=feedback,
+        iterations=iterations,
+        auto_refine=auto_refine,
+        max_iterations=max_iterations,
+        config=config,
+        vlm_provider=vlm_provider,
+        vlm_model=vlm_model,
+        image_provider=image_provider,
+        image_model=image_model,
+        output_format=output_format,
+        optimize=optimize,
+        save_prompts=save_prompts,
+        venue=venue,
+        generate_caption=generate_caption,
+    )
+
+
+@mcp.tool
+async def continue_plot(
+    run_id: str,
+    output_dir: str = "outputs",
+    feedback: str | None = None,
+    iterations: int | None = None,
+    auto_refine: bool = False,
+    max_iterations: int | None = None,
+    config: str | None = None,
+    vlm_provider: str | None = None,
+    vlm_model: str | None = None,
+    image_provider: str | None = None,
+    image_model: str | None = None,
+    output_format: str = "png",
+    optimize: bool = False,
+    save_prompts: bool | None = None,
+    venue: str | None = None,
+    generate_caption: bool = False,
+) -> str:
+    """Continue a statistical-plot run (same contract as ``continue_diagram``).
+
+    Use when ``run_input.json`` has ``diagram_type`` set to ``statistical_plot``.
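+
+    For reference, a success payload from either continue tool has this shape
+    (key names come from the shared implementation below; the values are
+    illustrative only, not real output)::
+
+        {
+          "strict_success": true,
+          "run_id": "run_20250109_120000_abc",
+          "run_dir": "/abs/path/outputs/run_20250109_120000_abc",
+          "final_image_path": "/abs/path/outputs/run_20250109_120000_abc/diagram_iter_2.png",
+          "metadata_path": null,
+          "diagram_type": "statistical_plot",
+          "new_iteration_count": 1
+        }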
+    """
+    return await _continue_run_mcp(
+        expected=DiagramType.STATISTICAL_PLOT,
+        run_id=run_id,
+        output_dir=output_dir,
+        feedback=feedback,
+        iterations=iterations,
+        auto_refine=auto_refine,
+        max_iterations=max_iterations,
+        config=config,
+        vlm_provider=vlm_provider,
+        vlm_model=vlm_model,
+        image_provider=image_provider,
+        image_model=image_model,
+        output_format=output_format,
+        optimize=optimize,
+        save_prompts=save_prompts,
+        venue=venue,
+        generate_caption=generate_caption,
+    )
+
+
 @mcp.tool
 async def evaluate_diagram(
     generated_path: str,
@@ -417,6 +531,157 @@ def _json_result(payload: dict) -> str:
     return json.dumps(payload, indent=2)
 
 
+def _load_dotenv_best_effort() -> None:
+    try:
+        from dotenv import load_dotenv
+
+        load_dotenv()
+    except ImportError:
+        pass
+
+
+def _mcp_settings_for_continue(
+    output_dir: str,
+    config: str | None,
+    *,
+    output_format: str = "png",
+    iterations: int | None = None,
+    auto_refine: bool = False,
+    max_iterations: int | None = None,
+    optimize: bool = False,
+    vlm_provider: str | None = None,
+    vlm_model: str | None = None,
+    image_provider: str | None = None,
+    image_model: str | None = None,
+    save_prompts: bool | None = None,
+    venue: str | None = None,
+    generate_caption: bool = False,
+) -> Settings:
+    """Build Settings for continue-run tools (mirrors CLI ``generate`` overrides)."""
+    overrides: dict[str, Any] = {
+        "output_dir": (output_dir or "outputs").strip() or "outputs",
+        "output_format": output_format.lower(),
+        "auto_refine": bool(auto_refine),
+        "optimize_inputs": bool(optimize),
+        "generate_caption": bool(generate_caption),
+    }
+    if iterations is not None:
+        overrides["refinement_iterations"] = max(1, int(iterations))
+    if max_iterations is not None:
+        overrides["max_iterations"] = max(1, int(max_iterations))
+    if vlm_provider:
+        overrides["vlm_provider"] = vlm_provider.strip()
+    if vlm_model:
+        overrides["vlm_model"] = vlm_model.strip()
+    if image_provider:
+        overrides["image_provider"] = image_provider.strip()
+    if image_model:
+        overrides["image_model"] = image_model.strip()
+    if save_prompts is not None:
+        overrides["save_prompts"] = save_prompts
+    if venue:
+        overrides["venue"] = venue.strip()
+
+    cfg = (config or "").strip()
+    if cfg:
+        return Settings.from_yaml(Path(cfg).expanduser(), **overrides)
+    _load_dotenv_best_effort()
+    return Settings(**overrides)
+
+
+async def _continue_run_mcp(
+    *,
+    expected: DiagramType,
+    run_id: str,
+    output_dir: str = "outputs",
+    feedback: str | None = None,
+    iterations: int | None = None,
+    auto_refine: bool = False,
+    max_iterations: int | None = None,
+    config: str | None = None,
+    vlm_provider: str | None = None,
+    vlm_model: str | None = None,
+    image_provider: str | None = None,
+    image_model: str | None = None,
+    output_format: str = "png",
+    optimize: bool = False,
+    save_prompts: bool | None = None,
+    venue: str | None = None,
+    generate_caption: bool = False,
+) -> str:
+    """Shared implementation for ``continue_diagram`` / ``continue_plot``."""
+    tool_name = "continue_diagram" if expected == DiagramType.METHODOLOGY else "continue_plot"
+    settings = _mcp_settings_for_continue(
+        output_dir,
+        config,
+        output_format=output_format,
+        iterations=iterations,
+        auto_refine=auto_refine,
+        max_iterations=max_iterations,
+        optimize=optimize,
+        vlm_provider=vlm_provider,
+        vlm_model=vlm_model,
+        image_provider=image_provider,
+        image_model=image_model,
+        save_prompts=save_prompts,
+        venue=venue,
+        generate_caption=generate_caption,
+    )
+    try:
+        resume_state = load_resume_state(settings.output_dir, run_id.strip())
+    except (FileNotFoundError, ValueError) as e:
+        return _json_result({"error": str(e), "strict_success": False})
+
+    if resume_state.diagram_type != expected:
+        other = "continue_plot" if expected == DiagramType.METHODOLOGY else "continue_diagram"
+        return _json_result(
+            {
+                "error": f"This run is {resume_state.diagram_type.value}; use {other} instead.",
+                "strict_success": False,
+                "actual_diagram_type": resume_state.diagram_type.value,
+            }
+        )
+
+    def _on_progress(event: str, payload: dict) -> None:
+        logger.info(
+            "mcp_progress",
+            tool=tool_name,
+            progress_event=event,
+            **payload,
+        )
+
+    try:
+        pipeline = PaperBananaPipeline(settings=settings, progress_callback=_on_progress)
+        result = await pipeline.continue_run(
+            resume_state=resume_state,
+            additional_iterations=iterations,
+            user_feedback=(feedback or "").strip() or None,
+            progress_callback=None,
+        )
+    except Exception as e:
+        logger.exception("mcp_continue_failed", tool=tool_name)
+        return _json_result({"error": str(e), "strict_success": False})
+
+    run_dir = Path(resume_state.run_dir)
+    meta = run_dir / "metadata.json"
+    payload: dict[str, Any] = {
+        "strict_success": True,
+        "run_id": resume_state.run_id,
+        "run_dir": str(run_dir.resolve()),
+        "final_image_path": str(Path(result.image_path).resolve()),
+        "metadata_path": str(meta.resolve()) if meta.is_file() else None,
+        "diagram_type": resume_state.diagram_type.value,
+        "new_iteration_count": len(result.iterations),
+    }
+    if result.vector_svg_path:
+        payload["vector_svg_path"] = str(Path(result.vector_svg_path).resolve())
+    if result.vector_pdf_path:
+        payload["vector_pdf_path"] = str(Path(result.vector_pdf_path).resolve())
+    if result.generated_caption:
+        payload["generated_caption"] = result.generated_caption
+    return _json_result(payload)
+
+
 @mcp.tool
 async def orchestrate_figures(
     paper: str | None = None,
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 8f56f3db..b47e6d23 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
+import json
 from io import BytesIO
 from pathlib import Path
 
@@ -256,3 +257,129 @@ def test_uncompressible_raises(self, tmp_path: Path, monkeypatch):
         with pytest.raises(ValueError, match="could not be compressed"):
             _compress_for_api(str(p))
+
+
+@pytest.mark.skipif(not _has_fastmcp, reason="fastmcp not installed")
+class TestMcpContinueRun:
+    """Tests for MCP continue_diagram / continue_plot helpers (no live API)."""
+
+    @staticmethod
+    def _write_resumable_run(
+        tmp_path: Path, *, diagram_type: str, run_id: str = "run_mcp_test"
+    ) -> Path:
+        out = tmp_path / "outputs"
+        run_dir = out / run_id
+        run_dir.mkdir(parents=True)
+        (run_dir / "run_input.json").write_text(
+            json.dumps(
+                {
+                    "source_context": "Paper describes a two-phase pipeline.",
+                    "communicative_intent": "Overview figure.",
+                    "diagram_type": diagram_type,
+                }
+            ),
+            encoding="utf-8",
+        )
+        iter_dir = run_dir / "iter_1"
+        iter_dir.mkdir()
+        (iter_dir / "details.json").write_text(
+            json.dumps(
+                {
+                    "description": "First draft layout",
+                    "critique": {"revised_description": "Revised layout text"},
+                }
+            ),
+            encoding="utf-8",
+        )
+        png = run_dir / "diagram_iter_1.png"
+        _write_png(png)
+        return out
+
+    @pytest.mark.asyncio
+    async def test_continue_diagram_rejects_plot_run(self, tmp_path: Path):
+        from mcp_server.server import _continue_run_mcp
+        from paperbanana.core.types import DiagramType
+
+        out = self._write_resumable_run(tmp_path, diagram_type="statistical_plot")
diagram_type="statistical_plot") + raw = await _continue_run_mcp( + expected=DiagramType.METHODOLOGY, + run_id="run_mcp_test", + output_dir=str(out), + ) + data = json.loads(raw) + assert data["strict_success"] is False + assert "continue_plot" in data["error"] + + @pytest.mark.asyncio + async def test_continue_plot_rejects_methodology_run(self, tmp_path: Path): + from mcp_server.server import _continue_run_mcp + from paperbanana.core.types import DiagramType + + out = self._write_resumable_run(tmp_path, diagram_type="methodology") + raw = await _continue_run_mcp( + expected=DiagramType.STATISTICAL_PLOT, + run_id="run_mcp_test", + output_dir=str(out), + ) + data = json.loads(raw) + assert data["strict_success"] is False + assert "continue_diagram" in data["error"] + + @pytest.mark.asyncio + async def test_continue_diagram_missing_run(self, tmp_path: Path): + from mcp_server.server import _continue_run_mcp + from paperbanana.core.types import DiagramType + + out = tmp_path / "empty_outputs" + out.mkdir() + raw = await _continue_run_mcp( + expected=DiagramType.METHODOLOGY, + run_id="run_nope", + output_dir=str(out), + ) + data = json.loads(raw) + assert data["strict_success"] is False + assert "error" in data + + @pytest.mark.asyncio + async def test_continue_diagram_success_mocked_pipeline(self, tmp_path: Path, monkeypatch): + import mcp_server.server as mcp_server_mod + from mcp_server.server import _continue_run_mcp + from paperbanana.core.types import DiagramType, GenerationOutput + + out = self._write_resumable_run(tmp_path, diagram_type="methodology") + + class _FakePipeline: + def __init__(self, settings=None, progress_callback=None): + self.settings = settings + self._progress_callback = progress_callback + + async def continue_run( + self, + resume_state, + additional_iterations=None, + user_feedback=None, + progress_callback=None, + ): + final = tmp_path / "after_continue.png" + _write_png(final) + return GenerationOutput( + image_path=str(final), + description="final desc", + iterations=[], + metadata={"run_id": resume_state.run_id}, + ) + + monkeypatch.setattr(mcp_server_mod, "PaperBananaPipeline", _FakePipeline) + + raw = await _continue_run_mcp( + expected=DiagramType.METHODOLOGY, + run_id="run_mcp_test", + output_dir=str(out), + iterations=2, + ) + data = json.loads(raw) + assert data["strict_success"] is True + assert data["run_id"] == "run_mcp_test" + assert Path(data["final_image_path"]).name == "after_continue.png" + assert data["new_iteration_count"] == 0