6 changes: 6 additions & 0 deletions mcp_server/README.md
@@ -8,6 +8,8 @@ MCP server that exposes PaperBanana's diagram and plot generation as tools for C
|------|-------------|
| `generate_diagram` | Generate a methodology diagram from text context + caption |
| `generate_plot` | Generate a statistical plot from JSON data + intent description |
| `continue_diagram` | Continue a prior methodology `run_*` (more refinement and/or critic feedback); returns JSON paths |
| `continue_plot` | Continue a prior statistical-plot `run_*`; same JSON contract as `continue_diagram` |
| `evaluate_diagram` | Compare a generated diagram against a human reference (4 dimensions) |
| `evaluate_plot` | Compare a generated statistical plot against a human reference (4 dimensions) |
| `download_references` | Download the expanded reference set for stronger retrieval |
@@ -23,6 +25,10 @@ Long runs execute in a **worker thread** so they do not block the MCP server eve

On validation errors (missing manifest, bad flags), the JSON body includes `"error"` and `"strict_success": false`.
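
For example, a minimal sketch of such a failure body (the message text is an illustrative placeholder):

```json
{
  "error": "<reason, e.g. missing manifest or bad flag>",
  "strict_success": false
}
```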

### Continue tools

`continue_diagram` and `continue_plot` mirror `paperbanana generate --continue-run` and the Studio continue flow: they load `run_input.json` and the latest iteration under `output_dir` / `run_id`, then run additional visualizer–critic rounds. Pick the tool that matches the run's `diagram_type` (`methodology` vs `statistical_plot`); otherwise the response is `strict_success: false` with a hint to use the other tool. Successful responses include `final_image_path`, `run_dir`, and `metadata_path` when present.
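
On success, the returned JSON string looks roughly like the sketch below (paths and the iteration count are illustrative placeholders; `vector_svg_path`, `vector_pdf_path`, and `generated_caption` appear only when produced):

```json
{
  "strict_success": true,
  "run_id": "run_20250109_120000_abc",
  "run_dir": "/abs/path/outputs/run_20250109_120000_abc",
  "final_image_path": "/abs/path/outputs/run_20250109_120000_abc/<latest-image>.png",
  "metadata_path": "/abs/path/outputs/run_20250109_120000_abc/metadata.json",
  "diagram_type": "methodology",
  "new_iteration_count": 2
}
```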

## Installation

### Quick Install (via `uvx`)
265 changes: 265 additions & 0 deletions mcp_server/server.py
@@ -6,6 +6,8 @@
Tools:
generate_diagram — Generate a methodology diagram from text
generate_plot — Generate a statistical plot from JSON data
continue_diagram — Continue a prior methodology run (more refinement / feedback)
continue_plot — Continue a prior statistical-plot run
evaluate_diagram — Evaluate a generated diagram against a reference
evaluate_plot — Evaluate a generated plot against a reference
download_references — Download expanded reference set (~294 examples)
@@ -24,6 +26,7 @@
import os
from io import BytesIO
from pathlib import Path
from typing import Any

import structlog
from fastmcp import FastMCP
@@ -32,6 +35,7 @@

from paperbanana.core.config import Settings
from paperbanana.core.pipeline import PaperBananaPipeline
from paperbanana.core.resume import load_resume_state
from paperbanana.core.types import DiagramType, GenerationInput
from paperbanana.core.utils import detect_image_mime_type, find_prompt_dir
from paperbanana.core.workflow_runner import (
@@ -281,6 +285,116 @@ def _on_progress(event: str, payload: dict) -> None:
    return Image(path=effective_path, format=fmt)


@mcp.tool
async def continue_diagram(
    run_id: str,
    output_dir: str = "outputs",
    feedback: str | None = None,
    iterations: int | None = None,
    auto_refine: bool = False,
    max_iterations: int | None = None,
    config: str | None = None,
    vlm_provider: str | None = None,
    vlm_model: str | None = None,
    image_provider: str | None = None,
    image_model: str | None = None,
    output_format: str = "png",
    optimize: bool = False,
    save_prompts: bool | None = None,
    venue: str | None = None,
    generate_caption: bool = False,
) -> str:
    """Continue a methodology diagram run under ``output_dir`` / ``run_id``.

    Loads ``run_input.json`` and the latest iteration from an existing ``run_*``
    directory (same as ``paperbanana generate --continue-run``). Runs more
    visualizer–critic rounds without redoing retrieval / planner / stylist.

    Args:
        run_id: Directory name (e.g. ``run_20250109_120000_abc``).
        output_dir: Base output directory containing the run folder.
        feedback: Optional notes for the critic (same as CLI ``--feedback``).
        iterations: Extra refinement rounds when ``auto_refine`` is false;
            also sets ``refinement_iterations`` in settings when provided.
            When omitted, uses the configured default iteration count.
        auto_refine: If true, loop until the critic is satisfied (capped by
            ``max_iterations`` or settings default).
        max_iterations: Cap when ``auto_refine`` is true.
        config: Optional path to YAML config (same as other MCP tools).
        vlm_provider, vlm_model, image_provider, image_model: Optional overrides.
        output_format: png, jpeg, or webp.
        optimize: Passed through to settings (normally unused for continue).
        save_prompts, venue, generate_caption: Optional settings overrides.

    Returns:
        JSON string with ``strict_success``, paths, and ``new_iteration_count``;
        on failure, ``strict_success`` is false and ``error`` explains why.
    """
    return await _continue_run_mcp(
        expected=DiagramType.METHODOLOGY,
        run_id=run_id,
        output_dir=output_dir,
        feedback=feedback,
        iterations=iterations,
        auto_refine=auto_refine,
        max_iterations=max_iterations,
        config=config,
        vlm_provider=vlm_provider,
        vlm_model=vlm_model,
        image_provider=image_provider,
        image_model=image_model,
        output_format=output_format,
        optimize=optimize,
        save_prompts=save_prompts,
        venue=venue,
        generate_caption=generate_caption,
    )


@mcp.tool
async def continue_plot(
    run_id: str,
    output_dir: str = "outputs",
    feedback: str | None = None,
    iterations: int | None = None,
    auto_refine: bool = False,
    max_iterations: int | None = None,
    config: str | None = None,
    vlm_provider: str | None = None,
    vlm_model: str | None = None,
    image_provider: str | None = None,
    image_model: str | None = None,
    output_format: str = "png",
    optimize: bool = False,
    save_prompts: bool | None = None,
    venue: str | None = None,
    generate_caption: bool = False,
) -> str:
    """Continue a statistical-plot run (same contract as ``continue_diagram``).

    Use when ``run_input.json`` has ``diagram_type`` ``statistical_plot``.
    """
    return await _continue_run_mcp(
        expected=DiagramType.STATISTICAL_PLOT,
        run_id=run_id,
        output_dir=output_dir,
        feedback=feedback,
        iterations=iterations,
        auto_refine=auto_refine,
        max_iterations=max_iterations,
        config=config,
        vlm_provider=vlm_provider,
        vlm_model=vlm_model,
        image_provider=image_provider,
        image_model=image_model,
        output_format=output_format,
        optimize=optimize,
        save_prompts=save_prompts,
        venue=venue,
        generate_caption=generate_caption,
    )
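

# Illustrative client-side sketch (session/variable names are assumptions, not part
# of this module): with an MCP client session, the tool can be invoked roughly as
#     result = await session.call_tool(
#         "continue_diagram",
#         {"run_id": "run_20250109_120000_abc", "feedback": "thicker arrows"},
#     )
# and the returned text content parses with json.loads into the payload built by
# _continue_run_mcp below.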


@mcp.tool
async def evaluate_diagram(
    generated_path: str,
@@ -417,6 +531,157 @@ def _json_result(payload: dict) -> str:
    return json.dumps(payload, indent=2)


def _load_dotenv_best_effort() -> None:
    try:
        from dotenv import load_dotenv

        load_dotenv()
    except ImportError:
        pass


def _mcp_settings_for_continue(
    output_dir: str,
    config: str | None,
    *,
    output_format: str = "png",
    iterations: int | None = None,
    auto_refine: bool = False,
    max_iterations: int | None = None,
    optimize: bool = False,
    vlm_provider: str | None = None,
    vlm_model: str | None = None,
    image_provider: str | None = None,
    image_model: str | None = None,
    save_prompts: bool | None = None,
    venue: str | None = None,
    generate_caption: bool = False,
) -> Settings:
    """Build Settings for continue-run tools (mirrors CLI ``generate`` overrides)."""
    overrides: dict[str, Any] = {
        "output_dir": (output_dir or "outputs").strip() or "outputs",
        "output_format": output_format.lower(),
        "auto_refine": bool(auto_refine),
        "optimize_inputs": bool(optimize),
        "generate_caption": bool(generate_caption),
    }
    if iterations is not None:
        overrides["refinement_iterations"] = max(1, int(iterations))
    if max_iterations is not None:
        overrides["max_iterations"] = max(1, int(max_iterations))
    if vlm_provider:
        overrides["vlm_provider"] = vlm_provider.strip()
    if vlm_model:
        overrides["vlm_model"] = vlm_model.strip()
    if image_provider:
        overrides["image_provider"] = image_provider.strip()
    if image_model:
        overrides["image_model"] = image_model.strip()
    if save_prompts is not None:
        overrides["save_prompts"] = save_prompts
    if venue:
        overrides["venue"] = venue.strip()

    cfg = (config or "").strip()
    if cfg:
        return Settings.from_yaml(Path(cfg).expanduser(), **overrides)
    _load_dotenv_best_effort()
    return Settings(**overrides)
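

# For example (illustrative values): _mcp_settings_for_continue("outputs", None,
# iterations=2, vlm_model="my-model") builds overrides containing
#     {"output_dir": "outputs", "refinement_iterations": 2, "vlm_model": "my-model", ...}
# and applies them on top of the YAML config when one is given, otherwise on Settings().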


async def _continue_run_mcp(
    *,
    expected: DiagramType,
    run_id: str,
    output_dir: str = "outputs",
    feedback: str | None = None,
    iterations: int | None = None,
    auto_refine: bool = False,
    max_iterations: int | None = None,
    config: str | None = None,
    vlm_provider: str | None = None,
    vlm_model: str | None = None,
    image_provider: str | None = None,
    image_model: str | None = None,
    output_format: str = "png",
    optimize: bool = False,
    save_prompts: bool | None = None,
    venue: str | None = None,
    generate_caption: bool = False,
) -> str:
    """Shared implementation for ``continue_diagram`` / ``continue_plot``."""
    tool_name = "continue_diagram" if expected == DiagramType.METHODOLOGY else "continue_plot"
    settings = _mcp_settings_for_continue(
        output_dir,
        config,
        output_format=output_format,
        iterations=iterations,
        auto_refine=auto_refine,
        max_iterations=max_iterations,
        optimize=optimize,
        vlm_provider=vlm_provider,
        vlm_model=vlm_model,
        image_provider=image_provider,
        image_model=image_model,
        save_prompts=save_prompts,
        venue=venue,
        generate_caption=generate_caption,
    )
    try:
        resume_state = load_resume_state(settings.output_dir, run_id.strip())
    except (FileNotFoundError, ValueError) as e:
        return _json_result({"error": str(e), "strict_success": False})

    if resume_state.diagram_type != expected:
        other = "continue_plot" if expected == DiagramType.METHODOLOGY else "continue_diagram"
        return _json_result(
            {
                "error": (f"This run is {resume_state.diagram_type.value}; use {other} instead."),
                "strict_success": False,
                "actual_diagram_type": resume_state.diagram_type.value,
            }
        )

    def _on_progress(event: str, payload: dict) -> None:
        logger.info(
            "mcp_progress",
            tool=tool_name,
            progress_event=event,
            **payload,
        )

    try:
        pipeline = PaperBananaPipeline(settings=settings, progress_callback=_on_progress)
        result = await pipeline.continue_run(
            resume_state=resume_state,
            additional_iterations=iterations,
            user_feedback=(feedback or "").strip() or None,
            progress_callback=None,
        )
    except Exception as e:
        logger.exception("mcp_continue_failed", tool=tool_name)
        return _json_result({"error": str(e), "strict_success": False})

    run_dir = Path(resume_state.run_dir)
    meta = run_dir / "metadata.json"
    payload: dict[str, Any] = {
        "strict_success": True,
        "run_id": resume_state.run_id,
        "run_dir": str(run_dir.resolve()),
        "final_image_path": str(Path(result.image_path).resolve()),
        "metadata_path": str(meta.resolve()) if meta.is_file() else None,
        "diagram_type": resume_state.diagram_type.value,
        "new_iteration_count": len(result.iterations),
    }
    if result.vector_svg_path:
        payload["vector_svg_path"] = str(Path(result.vector_svg_path).resolve())
    if result.vector_pdf_path:
        payload["vector_pdf_path"] = str(Path(result.vector_pdf_path).resolve())
    if result.generated_caption:
        payload["generated_caption"] = result.generated_caption
    return _json_result(payload)


@mcp.tool
async def orchestrate_figures(
paper: str | None = None,