diff --git a/README.md b/README.md index 6f3538e4..c296d8f4 100644 --- a/README.md +++ b/README.md @@ -290,6 +290,12 @@ paperbanana batch-report --batch-id batch_20250109_123456_abc --format html --ou Diagram batch reports include `batch_kind: methodology`; plot batches use `batch_kind: statistical_plot`. Human-readable reports (`paperbanana batch-report`) show the batch kind when present. +**Sweep manifests** let you store the full sweep plan as YAML/JSON instead of seven comma-separated CLI flags. `--manifest` is mutually exclusive with the axis flags; see `examples/sweep_manifest.yaml`. + +```bash +paperbanana sweep --manifest examples/sweep_manifest.yaml +``` + **Sweep reports** produced by `paperbanana sweep` can be rendered the same way: ```bash diff --git a/examples/sweep_manifest.yaml b/examples/sweep_manifest.yaml new file mode 100644 index 00000000..2ba285fd --- /dev/null +++ b/examples/sweep_manifest.yaml @@ -0,0 +1,26 @@ +# Example sweep manifest — one input, cartesian product over the axes below. +# Invoke with: +# paperbanana sweep --manifest examples/sweep_manifest.yaml +# +# --manifest is mutually exclusive with the axis CLI flags +# (--vlm-providers, --vlm-models, --image-providers, --image-models, +# --iterations, --optimize-modes, --auto-modes). +# +# --output-dir, --config, --format, --dry-run, --verbose, --auto-download-data +# remain CLI-only (they are invocation concerns, not sweep-plan concerns). 
+ +input: sample_inputs/transformer_method.txt +caption: "Overview of our encoder-decoder architecture with sparse routing" + +# Optional top-level keys: +# pdf_pages: "1-5" # PDF inputs only +# max_variants: 20 # cap the total cartesian product + +axes: + vlm_providers: [gemini, openai] + vlm_models: [] + image_providers: [google_imagen, openai_imagen] + image_models: [] + refinement_iterations: [2, 3] + optimize_inputs: [false, true] + auto_refine: [false] diff --git a/paperbanana/cli.py b/paperbanana/cli.py index 4fbbd48a..a56b2d25 100644 --- a/paperbanana/cli.py +++ b/paperbanana/cli.py @@ -834,18 +834,24 @@ def on_progress(event: PipelineProgressEvent) -> None: @app.command() def sweep( - input: str = typer.Option( - ..., + input: Optional[str] = typer.Option( + None, "--input", "-i", help="Path to methodology text file or PDF (.pdf requires: pip install 'paperbanana[pdf]')", ), - caption: str = typer.Option( - ..., + caption: Optional[str] = typer.Option( + None, "--caption", "-c", help="Figure caption / communicative intent", ), + manifest: Optional[str] = typer.Option( + None, + "--manifest", + "-m", + help="Path to sweep manifest (YAML or JSON). 
Mutually exclusive with axis flags.", + ), pdf_pages: Optional[str] = typer.Option( None, "--pdf-pages", @@ -929,8 +935,56 @@ def sweep( console.print("[red]Error: --max-variants must be >= 1[/red]") raise typer.Exit(1) + axis_flag_values = ( + vlm_providers, + vlm_models, + image_providers, + image_models, + iterations, + optimize_modes, + auto_modes, + ) + if manifest is not None and any(v is not None for v in axis_flag_values): + console.print( + "[red]Error: --manifest is mutually exclusive with axis flags " + "(--vlm-providers, --vlm-models, --image-providers, --image-models, " + "--iterations, --optimize-modes, --auto-modes)[/red]" + ) + raise typer.Exit(1) + if manifest is None and (not input or not caption): + console.print( + "[red]Error: --input and --caption are required unless --manifest is set[/red]" + ) + raise typer.Exit(1) + configure_logging(verbose=verbose) + from paperbanana.core.sweep import ( + build_sweep_variants, + load_sweep_manifest, + parse_csv_bools, + parse_csv_ints, + parse_csv_values, + quality_proxy_score, + rank_sweep_results, + summarize_sweep, + ) + + axes_from_manifest: dict[str, list] | None = None + if manifest is not None: + try: + parsed = load_sweep_manifest(Path(manifest)) + except (FileNotFoundError, ValueError, RuntimeError) as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) + input = parsed["input"] + caption = parsed["caption"] + if parsed["pdf_pages"] is not None: + pdf_pages = parsed["pdf_pages"] + if parsed["max_variants"] is not None: + max_variants = parsed["max_variants"] + axes_from_manifest = parsed["axes"] + input_path = Path(input) if not input_path.exists(): console.print(f"[red]Error: Input file not found: {input}[/red]") @@ -942,27 +996,30 @@ def sweep( load_dotenv() from paperbanana.core.source_loader import load_methodology_source - from paperbanana.core.sweep import ( - build_sweep_variants, - parse_csv_bools, - parse_csv_ints, - parse_csv_values, - quality_proxy_score, - 
rank_sweep_results, - summarize_sweep, - ) try: - variant_list = build_sweep_variants( - vlm_providers=parse_csv_values(vlm_providers), - vlm_models=parse_csv_values(vlm_models), - image_providers=parse_csv_values(image_providers), - image_models=parse_csv_values(image_models), - refinement_iterations=parse_csv_ints(iterations, field_name="--iterations"), - optimize_inputs=parse_csv_bools(optimize_modes, field_name="--optimize-modes"), - auto_refine=parse_csv_bools(auto_modes, field_name="--auto-modes"), - max_variants=max_variants, - ) + if axes_from_manifest is not None: + variant_list = build_sweep_variants( + vlm_providers=[str(x) for x in axes_from_manifest["vlm_providers"]], + vlm_models=[str(x) for x in axes_from_manifest["vlm_models"]], + image_providers=[str(x) for x in axes_from_manifest["image_providers"]], + image_models=[str(x) for x in axes_from_manifest["image_models"]], + refinement_iterations=[int(x) for x in axes_from_manifest["refinement_iterations"]], + optimize_inputs=[bool(x) for x in axes_from_manifest["optimize_inputs"]], + auto_refine=[bool(x) for x in axes_from_manifest["auto_refine"]], + max_variants=max_variants, + ) + else: + variant_list = build_sweep_variants( + vlm_providers=parse_csv_values(vlm_providers), + vlm_models=parse_csv_values(vlm_models), + image_providers=parse_csv_values(image_providers), + image_models=parse_csv_values(image_models), + refinement_iterations=parse_csv_ints(iterations, field_name="--iterations"), + optimize_inputs=parse_csv_bools(optimize_modes, field_name="--optimize-modes"), + auto_refine=parse_csv_bools(auto_modes, field_name="--auto-modes"), + max_variants=max_variants, + ) except ValueError as e: console.print(f"[red]Error: {e}[/red]") raise typer.Exit(1) diff --git a/paperbanana/core/sweep.py b/paperbanana/core/sweep.py index e84472bf..39fac079 100644 --- a/paperbanana/core/sweep.py +++ b/paperbanana/core/sweep.py @@ -106,6 +106,97 @@ def parse_csv_bools(raw: str | None, *, field_name: str) -> 
SWEEP_MANIFEST_REQUIRED_KEYS = ("input", "caption")
SWEEP_MANIFEST_AXIS_KEYS = (
    "vlm_providers",
    "vlm_models",
    "image_providers",
    "image_models",
    "refinement_iterations",
    "optimize_inputs",
    "auto_refine",
)
# Axes whose elements must already be real ints / real bools in the manifest.
# The CLI coerces these axes with ``int(x)`` / ``bool(x)``, so a quoted
# ``"false"`` would silently become ``True`` and ``2.7`` would silently become
# ``2`` unless they are rejected here.
_SWEEP_MANIFEST_INT_AXES = frozenset({"refinement_iterations"})
_SWEEP_MANIFEST_BOOL_AXES = frozenset({"optimize_inputs", "auto_refine"})


def _parse_sweep_manifest_text(raw: str, manifest_path: Path) -> Any:
    """Decode manifest text as YAML or JSON, chosen by the file suffix.

    Raises:
        RuntimeError: the manifest is YAML but PyYAML is not installed.
        ValueError: the suffix is unsupported or the text is not valid
            YAML/JSON.
    """
    suffix = manifest_path.suffix.lower()
    if suffix in (".yaml", ".yml"):
        try:
            import yaml
        except ImportError as exc:
            raise RuntimeError(
                "PyYAML is required for YAML manifests. Install with: pip install pyyaml"
            ) from exc
        try:
            return yaml.safe_load(raw)
        except yaml.YAMLError as exc:
            # YAMLError is not a ValueError; convert it so callers that only
            # handle ValueError report a clean message instead of a traceback.
            raise ValueError(f"Sweep manifest is not valid YAML: {exc}") from exc
    if suffix == ".json":
        try:
            return json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Sweep manifest is not valid JSON: {exc}") from exc
    raise ValueError(f"Manifest must be .yaml, .yml, or .json. Got: {manifest_path.suffix}")


def _normalize_sweep_axes(axes_raw: Any) -> dict[str, list[Any]]:
    """Validate the ``axes`` mapping and return one list per known axis key."""
    if not isinstance(axes_raw, dict):
        raise ValueError("Sweep manifest 'axes' must be a mapping when set")
    unknown = set(axes_raw) - set(SWEEP_MANIFEST_AXIS_KEYS)
    if unknown:
        # key=str keeps the sort from failing when keys have mixed types
        # (e.g. an integer key next to string keys).
        raise ValueError(
            f"Sweep manifest 'axes' has unknown keys: {sorted(unknown, key=str)}. "
            f"Allowed: {list(SWEEP_MANIFEST_AXIS_KEYS)}"
        )

    axes: dict[str, list[Any]] = {}
    for key in SWEEP_MANIFEST_AXIS_KEYS:
        value = axes_raw.get(key)
        if value is None:
            # Unset (or explicitly null) axes default to an empty list.
            axes[key] = []
            continue
        if not isinstance(value, list):
            raise ValueError(f"Sweep manifest axis '{key}' must be a list")
        for item in value:
            if key in _SWEEP_MANIFEST_BOOL_AXES and not isinstance(item, bool):
                raise ValueError(
                    f"Sweep manifest axis '{key}' must contain only booleans; got {item!r}"
                )
            # bool is a subclass of int, so exclude it explicitly.
            if key in _SWEEP_MANIFEST_INT_AXES and (
                isinstance(item, bool) or not isinstance(item, int)
            ):
                raise ValueError(
                    f"Sweep manifest axis '{key}' must contain only integers; got {item!r}"
                )
        axes[key] = list(value)
    return axes


def load_sweep_manifest(manifest_path: Path) -> dict[str, Any]:
    """Load a sweep manifest (YAML or JSON) and return a normalized dict.

    Required top-level keys: ``input`` (file path, string) and ``caption``
    (string). Optional top-level keys: ``pdf_pages`` (str), ``max_variants``
    (int >= 1), and ``axes`` (mapping of the seven axis keys in
    ``SWEEP_MANIFEST_AXIS_KEYS`` to lists).

    The ``input`` path is resolved relative to the manifest's parent directory
    when not absolute (matching ``load_batch_manifest``).

    Args:
        manifest_path: Path to a ``.yaml``, ``.yml`` or ``.json`` manifest.

    Returns:
        A dict with keys ``input`` (absolute path as ``str``), ``caption``,
        ``pdf_pages`` (``str`` or ``None``), ``max_variants`` (``int`` or
        ``None``) and ``axes`` (every axis key present, unset axes as ``[]``).

    Raises:
        FileNotFoundError: the manifest path is not an existing file.
        RuntimeError: a YAML manifest was given but PyYAML is not installed.
        ValueError: the manifest is malformed or fails validation.
    """
    manifest_path = Path(manifest_path).resolve()
    # is_file() also rejects directories, which would otherwise surface as an
    # OSError from read_text() below.
    if not manifest_path.is_file():
        raise FileNotFoundError(f"Sweep manifest not found: {manifest_path}")
    raw = manifest_path.read_text(encoding="utf-8")
    data = _parse_sweep_manifest_text(raw, manifest_path)

    if not isinstance(data, dict):
        raise ValueError("Sweep manifest must be a mapping at the top level")
    for key in SWEEP_MANIFEST_REQUIRED_KEYS:
        if not data.get(key):
            raise ValueError(f"Sweep manifest is missing required key: '{key}'")
    # Both required keys must be strings; a non-string ``input`` would
    # otherwise raise TypeError from Path() further down.
    for key in SWEEP_MANIFEST_REQUIRED_KEYS:
        if not isinstance(data[key], str):
            raise ValueError(f"Sweep manifest '{key}' must be a string")

    pdf_pages = data.get("pdf_pages")
    if pdf_pages is not None and not isinstance(pdf_pages, str):
        raise ValueError("Sweep manifest 'pdf_pages' must be a string when set")

    max_variants = data.get("max_variants")
    if max_variants is not None:
        # bool is a subclass of int; ``max_variants: true`` is not a count.
        if not isinstance(max_variants, int) or isinstance(max_variants, bool):
            raise ValueError("Sweep manifest 'max_variants' must be an integer when set")
        if max_variants < 1:
            raise ValueError("Sweep manifest 'max_variants' must be >= 1")

    axes = _normalize_sweep_axes(data.get("axes") or {})

    input_path = Path(data["input"])
    if not input_path.is_absolute():
        input_path = (manifest_path.parent / input_path).resolve()

    return {
        "input": str(input_path),
        "caption": data["caption"],
        "pdf_pages": pdf_pages,
        "max_variants": max_variants,
        "axes": axes,
    }
vlm_providers: [gemini, openai] + refinement_iterations: [2, 3] + optimize_inputs: [false, true] +""", + ) + parsed = load_sweep_manifest(manifest) + assert parsed["input"] == str(input_file.resolve()) + assert parsed["caption"] == "Test caption" + assert parsed["pdf_pages"] == "1-3" + assert parsed["max_variants"] == 5 + assert parsed["axes"]["vlm_providers"] == ["gemini", "openai"] + assert parsed["axes"]["refinement_iterations"] == [2, 3] + assert parsed["axes"]["optimize_inputs"] == [False, True] + # Unset axes default to empty lists. + assert parsed["axes"]["auto_refine"] == [] + assert parsed["axes"]["image_models"] == [] + + +def test_load_sweep_manifest_json_success(tmp_path: Path) -> None: + (tmp_path / "method.txt").write_text("x", encoding="utf-8") + payload = { + "input": "method.txt", + "caption": "A caption", + "axes": {"vlm_providers": ["gemini"]}, + } + path = tmp_path / "sweep.json" + path.write_text(json.dumps(payload), encoding="utf-8") + parsed = load_sweep_manifest(path) + assert parsed["caption"] == "A caption" + assert parsed["axes"]["vlm_providers"] == ["gemini"] + assert parsed["pdf_pages"] is None + assert parsed["max_variants"] is None + + +def test_load_sweep_manifest_resolves_relative_input(tmp_path: Path) -> None: + nested = tmp_path / "nested" + nested.mkdir() + input_file = nested / "method.txt" + input_file.write_text("body", encoding="utf-8") + manifest = _write_manifest_yaml( + tmp_path, + 'input: nested/method.txt\ncaption: "Test"\n', + ) + parsed = load_sweep_manifest(manifest) + assert parsed["input"] == str(input_file.resolve()) + + +def test_load_sweep_manifest_missing_file() -> None: + with pytest.raises(FileNotFoundError, match="Sweep manifest not found"): + load_sweep_manifest(Path("/nonexistent/sweep.yaml")) + + +def test_load_sweep_manifest_unsupported_suffix(tmp_path: Path) -> None: + p = tmp_path / "sweep.txt" + p.write_text("input: x\ncaption: y", encoding="utf-8") + with pytest.raises(ValueError, match=".yaml, .yml, 
or .json"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_top_level_not_mapping(tmp_path: Path) -> None: + p = _write_manifest_yaml(tmp_path, "- just-a-list\n") + with pytest.raises(ValueError, match="mapping at the top level"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_missing_input(tmp_path: Path) -> None: + p = _write_manifest_yaml(tmp_path, 'caption: "Test"\n') + with pytest.raises(ValueError, match="missing required key: 'input'"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_missing_caption(tmp_path: Path) -> None: + p = _write_manifest_yaml(tmp_path, "input: method.txt\n") + with pytest.raises(ValueError, match="missing required key: 'caption'"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_caption_wrong_type(tmp_path: Path) -> None: + p = _write_manifest_yaml(tmp_path, "input: method.txt\ncaption: 42\n") + with pytest.raises(ValueError, match="'caption' must be a string"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_pdf_pages_wrong_type(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\npdf_pages: 42\n', + ) + with pytest.raises(ValueError, match="'pdf_pages' must be a string"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_max_variants_not_int(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\nmax_variants: "lots"\n', + ) + with pytest.raises(ValueError, match="'max_variants' must be an integer"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_max_variants_zero(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\nmax_variants: 0\n', + ) + with pytest.raises(ValueError, match="'max_variants' must be >= 1"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_axes_wrong_type(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\naxes: "not a dict"\n', + 
) + with pytest.raises(ValueError, match="'axes' must be a mapping"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_axis_value_not_list(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\naxes:\n vlm_providers: gemini\n', + ) + with pytest.raises(ValueError, match="axis 'vlm_providers' must be a list"): + load_sweep_manifest(p) + + +def test_load_sweep_manifest_unknown_axis_rejected(tmp_path: Path) -> None: + p = _write_manifest_yaml( + tmp_path, + 'input: method.txt\ncaption: "Test"\naxes:\n typo_axis: [a, b]\n', + ) + with pytest.raises(ValueError, match="unknown keys"): + load_sweep_manifest(p)