Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,12 @@ paperbanana batch-report --batch-id batch_20250109_123456_abc --format html --ou

Diagram batch reports include `batch_kind: methodology`; plot batches use `batch_kind: statistical_plot`. Human-readable reports (`paperbanana batch-report`) show the batch kind when present.

**Sweep manifests** let you store the full sweep plan as YAML/JSON instead of passing seven comma-separated axis flags on the command line. `--manifest` is mutually exclusive with the axis flags; see `examples/sweep_manifest.yaml`.

```bash
paperbanana sweep --manifest examples/sweep_manifest.yaml
```

**Sweep reports** produced by `paperbanana sweep` can be rendered the same way:

```bash
Expand Down
26 changes: 26 additions & 0 deletions examples/sweep_manifest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Example sweep manifest — one input, cartesian product over the axes below.
# Invoke with:
# paperbanana sweep --manifest examples/sweep_manifest.yaml
#
# --manifest is mutually exclusive with the axis CLI flags
# (--vlm-providers, --vlm-models, --image-providers, --image-models,
# --iterations, --optimize-modes, --auto-modes).
#
# --output-dir, --config, --format, --dry-run, --verbose, --auto-download-data
# remain CLI-only (they are invocation concerns, not sweep-plan concerns).

input: sample_inputs/transformer_method.txt
caption: "Overview of our encoder-decoder architecture with sparse routing"

# Optional top-level keys:
# pdf_pages: "1-5" # PDF inputs only
# max_variants: 20 # cap the total cartesian product

axes:
vlm_providers: [gemini, openai]
vlm_models: []
image_providers: [google_imagen, openai_imagen]
image_models: []
refinement_iterations: [2, 3]
optimize_inputs: [false, true]
auto_refine: [false]
103 changes: 80 additions & 23 deletions paperbanana/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,18 +834,24 @@ def on_progress(event: PipelineProgressEvent) -> None:

@app.command()
def sweep(
input: str = typer.Option(
...,
input: Optional[str] = typer.Option(
None,
"--input",
"-i",
help="Path to methodology text file or PDF (.pdf requires: pip install 'paperbanana[pdf]')",
),
caption: str = typer.Option(
...,
caption: Optional[str] = typer.Option(
None,
"--caption",
"-c",
help="Figure caption / communicative intent",
),
manifest: Optional[str] = typer.Option(
None,
"--manifest",
"-m",
help="Path to sweep manifest (YAML or JSON). Mutually exclusive with axis flags.",
),
pdf_pages: Optional[str] = typer.Option(
None,
"--pdf-pages",
Expand Down Expand Up @@ -929,8 +935,56 @@ def sweep(
console.print("[red]Error: --max-variants must be >= 1[/red]")
raise typer.Exit(1)

axis_flag_values = (
vlm_providers,
vlm_models,
image_providers,
image_models,
iterations,
optimize_modes,
auto_modes,
)
if manifest is not None and any(v is not None for v in axis_flag_values):
console.print(
"[red]Error: --manifest is mutually exclusive with axis flags "
"(--vlm-providers, --vlm-models, --image-providers, --image-models, "
"--iterations, --optimize-modes, --auto-modes)[/red]"
)
raise typer.Exit(1)
if manifest is None and (not input or not caption):
console.print(
"[red]Error: --input and --caption are required unless --manifest is set[/red]"
)
raise typer.Exit(1)

configure_logging(verbose=verbose)

from paperbanana.core.sweep import (
build_sweep_variants,
load_sweep_manifest,
parse_csv_bools,
parse_csv_ints,
parse_csv_values,
quality_proxy_score,
rank_sweep_results,
summarize_sweep,
)

axes_from_manifest: dict[str, list] | None = None
if manifest is not None:
try:
parsed = load_sweep_manifest(Path(manifest))
except (FileNotFoundError, ValueError, RuntimeError) as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1)
input = parsed["input"]
caption = parsed["caption"]
if parsed["pdf_pages"] is not None:
pdf_pages = parsed["pdf_pages"]
if parsed["max_variants"] is not None:
max_variants = parsed["max_variants"]
axes_from_manifest = parsed["axes"]

input_path = Path(input)
if not input_path.exists():
console.print(f"[red]Error: Input file not found: {input}[/red]")
Expand All @@ -942,27 +996,30 @@ def sweep(
load_dotenv()

from paperbanana.core.source_loader import load_methodology_source
from paperbanana.core.sweep import (
build_sweep_variants,
parse_csv_bools,
parse_csv_ints,
parse_csv_values,
quality_proxy_score,
rank_sweep_results,
summarize_sweep,
)

try:
variant_list = build_sweep_variants(
vlm_providers=parse_csv_values(vlm_providers),
vlm_models=parse_csv_values(vlm_models),
image_providers=parse_csv_values(image_providers),
image_models=parse_csv_values(image_models),
refinement_iterations=parse_csv_ints(iterations, field_name="--iterations"),
optimize_inputs=parse_csv_bools(optimize_modes, field_name="--optimize-modes"),
auto_refine=parse_csv_bools(auto_modes, field_name="--auto-modes"),
max_variants=max_variants,
)
if axes_from_manifest is not None:
variant_list = build_sweep_variants(
vlm_providers=[str(x) for x in axes_from_manifest["vlm_providers"]],
vlm_models=[str(x) for x in axes_from_manifest["vlm_models"]],
image_providers=[str(x) for x in axes_from_manifest["image_providers"]],
image_models=[str(x) for x in axes_from_manifest["image_models"]],
refinement_iterations=[int(x) for x in axes_from_manifest["refinement_iterations"]],
optimize_inputs=[bool(x) for x in axes_from_manifest["optimize_inputs"]],
auto_refine=[bool(x) for x in axes_from_manifest["auto_refine"]],
max_variants=max_variants,
)
else:
variant_list = build_sweep_variants(
vlm_providers=parse_csv_values(vlm_providers),
vlm_models=parse_csv_values(vlm_models),
image_providers=parse_csv_values(image_providers),
image_models=parse_csv_values(image_models),
refinement_iterations=parse_csv_ints(iterations, field_name="--iterations"),
optimize_inputs=parse_csv_bools(optimize_modes, field_name="--optimize-modes"),
auto_refine=parse_csv_bools(auto_modes, field_name="--auto-modes"),
max_variants=max_variants,
)
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1)
Expand Down
91 changes: 91 additions & 0 deletions paperbanana/core/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,97 @@ def parse_csv_bools(raw: str | None, *, field_name: str) -> list[bool]:
return parsed


SWEEP_MANIFEST_REQUIRED_KEYS = ("input", "caption")
SWEEP_MANIFEST_AXIS_KEYS = (
    "vlm_providers",
    "vlm_models",
    "image_providers",
    "image_models",
    "refinement_iterations",
    "optimize_inputs",
    "auto_refine",
)
# Axes whose items must already carry the right type in the manifest.
# A consumer that coerces with ``bool()``/``int()`` would otherwise turn the
# string "false" into True or silently truncate 2.7 to 2.
_SWEEP_MANIFEST_INT_AXES = frozenset({"refinement_iterations"})
_SWEEP_MANIFEST_BOOL_AXES = frozenset({"optimize_inputs", "auto_refine"})


def _validate_sweep_axis_items(key: str, items: list[Any]) -> None:
    """Raise ``ValueError`` if a typed axis holds an item of the wrong type."""
    if key in _SWEEP_MANIFEST_INT_AXES:
        # bool is a subclass of int, so it has to be excluded explicitly.
        if any(isinstance(item, bool) or not isinstance(item, int) for item in items):
            raise ValueError(f"Sweep manifest axis '{key}' must contain only integers")
    elif key in _SWEEP_MANIFEST_BOOL_AXES:
        if any(not isinstance(item, bool) for item in items):
            raise ValueError(f"Sweep manifest axis '{key}' must contain only booleans")


def load_sweep_manifest(manifest_path: Path) -> dict[str, Any]:
    """Load a sweep manifest (YAML or JSON) and return a normalized dict.

    Required top-level keys: ``input`` (file path) and ``caption`` (string).
    Optional top-level keys: ``pdf_pages`` (str), ``max_variants`` (int), and
    ``axes`` (object mapping the seven axis keys — see ``SWEEP_MANIFEST_AXIS_KEYS``
    — to lists).

    The ``input`` path is resolved relative to the manifest's parent directory
    when not absolute.

    Args:
        manifest_path: Location of the ``.yaml``, ``.yml`` or ``.json`` manifest.

    Returns:
        A dict with keys ``input`` (absolute path as ``str``), ``caption``,
        ``pdf_pages`` (``str`` or ``None``), ``max_variants`` (``int`` or
        ``None``) and ``axes`` (every key of ``SWEEP_MANIFEST_AXIS_KEYS`` mapped
        to a list; axes that are absent or null become ``[]``).

    Raises:
        FileNotFoundError: The manifest path does not point to a file.
        RuntimeError: A YAML manifest was given but PyYAML is not installed.
        ValueError: Unsupported suffix, unparsable content, or any schema
            violation (missing/ill-typed keys, unknown axes, ill-typed axis
            items).
    """
    manifest_path = Path(manifest_path).resolve()
    # is_file() rather than exists(): a directory would otherwise slip through
    # and fail inside read_text() with an unrelated OSError.
    if not manifest_path.is_file():
        raise FileNotFoundError(f"Sweep manifest not found: {manifest_path}")
    raw = manifest_path.read_text(encoding="utf-8")
    suffix = manifest_path.suffix.lower()
    if suffix in (".yaml", ".yml"):
        # Keep the import in its own try so only a missing PyYAML is reported
        # as a dependency problem.
        try:
            import yaml
        except ImportError as exc:
            raise RuntimeError(
                "PyYAML is required for YAML manifests. Install with: pip install pyyaml"
            ) from exc
        try:
            data = yaml.safe_load(raw)
        except yaml.YAMLError as exc:
            # YAMLError is not a ValueError; convert it so callers that handle
            # the documented exception types also see syntax errors.
            raise ValueError(f"Sweep manifest is not valid YAML: {exc}") from exc
    elif suffix == ".json":
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Sweep manifest is not valid JSON: {exc}") from exc
    else:
        raise ValueError(f"Manifest must be .yaml, .yml, or .json. Got: {manifest_path.suffix}")

    if not isinstance(data, dict):
        raise ValueError("Sweep manifest must be a mapping at the top level")
    for key in SWEEP_MANIFEST_REQUIRED_KEYS:
        if not data.get(key):
            raise ValueError(f"Sweep manifest is missing required key: '{key}'")
    # Path() raises TypeError on non-path values; reject them up front with the
    # same exception type as every other schema violation.
    if not isinstance(data["input"], str):
        raise ValueError("Sweep manifest 'input' must be a string path")
    if not isinstance(data["caption"], str):
        raise ValueError("Sweep manifest 'caption' must be a string")

    pdf_pages = data.get("pdf_pages")
    if pdf_pages is not None and not isinstance(pdf_pages, str):
        raise ValueError("Sweep manifest 'pdf_pages' must be a string when set")

    max_variants = data.get("max_variants")
    if max_variants is not None:
        if not isinstance(max_variants, int) or isinstance(max_variants, bool):
            raise ValueError("Sweep manifest 'max_variants' must be an integer when set")
        if max_variants < 1:
            raise ValueError("Sweep manifest 'max_variants' must be >= 1")

    axes_raw = data.get("axes") or {}
    if not isinstance(axes_raw, dict):
        raise ValueError("Sweep manifest 'axes' must be a mapping when set")
    unknown = set(axes_raw) - set(SWEEP_MANIFEST_AXIS_KEYS)
    if unknown:
        raise ValueError(
            # key=str keeps sorting safe when YAML yields mixed-type keys.
            f"Sweep manifest 'axes' has unknown keys: {sorted(unknown, key=str)}. "
            f"Allowed: {list(SWEEP_MANIFEST_AXIS_KEYS)}"
        )

    axes: dict[str, list[Any]] = {}
    for key in SWEEP_MANIFEST_AXIS_KEYS:
        value = axes_raw.get(key)
        if value is None:
            # Absent or explicitly-null axis: normalize to an empty list.
            axes[key] = []
            continue
        if not isinstance(value, list):
            raise ValueError(f"Sweep manifest axis '{key}' must be a list")
        _validate_sweep_axis_items(key, value)
        axes[key] = value

    input_path = Path(data["input"])
    if not input_path.is_absolute():
        input_path = (manifest_path.parent / input_path).resolve()

    return {
        "input": str(input_path),
        "caption": data["caption"],
        "pdf_pages": pdf_pages,
        "max_variants": max_variants,
        "axes": axes,
    }


def build_sweep_variants(
*,
vlm_providers: list[str],
Expand Down
Loading
Loading