diff --git a/paperbanana/studio/app.py b/paperbanana/studio/app.py index 13bf7372..3e1f92fa 100644 --- a/paperbanana/studio/app.py +++ b/paperbanana/studio/app.py @@ -19,8 +19,10 @@ run_continue, run_evaluate, run_methodology, + run_orchestrate, run_plot, run_plot_batch, + run_sweep, ) @@ -774,6 +776,272 @@ def _do_composite( outputs=[cmp_log, cmp_out], ) + # ── Sweep ───────────────────────────────────────────────────── + with gr.Tab("Sweep"): + gr.Markdown( + "Run a multi-variant methodology sweep and rank outputs by the built-in " + "quality proxy score. This currently uses a source file input." + ) + sw_input = gr.File( + label="Methodology source file", + file_types=[".txt", ".md", ".pdf"], + ) + sw_caption = gr.Textbox(label="Figure caption / communicative intent", lines=2) + sw_pdf_pages = gr.Textbox( + label="PDF pages (optional)", + placeholder="e.g. 1-5,7 (PDF inputs only)", + ) + with gr.Row(): + sw_vlm_providers = gr.Textbox( + label="VLM providers", + value="", + placeholder="Comma-separated, e.g. gemini,openai", + ) + sw_vlm_models = gr.Textbox( + label="VLM models", + value="", + placeholder="Optional comma-separated models", + ) + with gr.Row(): + sw_img_providers = gr.Textbox( + label="Image providers", + value="", + placeholder="Comma-separated, e.g. google_imagen,openai_imagen", + ) + sw_img_models = gr.Textbox( + label="Image models", + value="", + placeholder="Optional comma-separated models", + ) + with gr.Row(): + sw_iterations = gr.Textbox( + label="Iteration axis", + value="", + placeholder="Comma-separated ints, e.g. 2,3,4", + ) + sw_opt_modes = gr.Textbox( + label="Optimize axis", + value="", + placeholder="on,off", + ) + sw_auto_modes = gr.Textbox( + label="Auto-refine axis", + value="", + placeholder="off,on", + ) + with gr.Row(): + sw_max_variants = gr.Number( + label="Max variants (optional)", + value=None, + precision=0, + ) + sw_dry_run = gr.Checkbox(label="Dry run (plan only)", value=False) + sw_log = gr.Textbox(label="Sweep log", lines=16) + sw_dir = gr.Textbox(label="Sweep output directory") + sw_report = gr.Textbox(label="sweep_report.json path") + sw_go = gr.Button("Run sweep", variant="primary") + + def _do_sweep( + od, + c, + vp, + vm, + ip, + im, + fo, + it, + au, + mx, + op, + sp, + sd, + infile, + caption, + pdf_pages, + svp, + svm, + sip, + sim, + siters, + sopt, + sauto, + smax, + sdry, + ): + _dotenv() + try: + path = _upload_path(infile) + if not path: + return "Upload a methodology source file.", "", "" + st = _settings(od, c, vp, vm, ip, im, fo, it, au, mx, op, sp, sd) + max_variants_int: Optional[int] = None + if smax is not None and not (isinstance(smax, float) and math.isnan(smax)): + max_variants_int = int(smax) + log, sweep_dir, report_path = run_sweep( + st, + input_path=path, + caption=caption or "", + pdf_pages=(pdf_pages or "").strip() or None, + vlm_providers=svp or "", + vlm_models=svm or "", + image_providers=sip or "", + image_models=sim or "", + iterations=siters or "", + optimize_modes=sopt or "", + auto_modes=sauto or "", + max_variants=max_variants_int, + dry_run=bool(sdry), + verbose_logging=False, + ) + return log, sweep_dir, report_path + except Exception as e: + return f"{type(e).__name__}: {e}", "", "" + + sw_go.click( + _do_sweep, + inputs=[ + out_dir, + cfg, + vlm_p, + vlm_m, + img_p, + img_m, + fmt, + iters, + auto_ref, + max_it, + opt_in, + save_pr, + seed_val, + sw_input, + sw_caption, + sw_pdf_pages, + sw_vlm_providers, + sw_vlm_models, + sw_img_providers, + sw_img_models, + sw_iterations, + sw_opt_modes, + sw_auto_modes, + sw_max_variants, + sw_dry_run, + ], + outputs=[sw_log, sw_dir, sw_report], + ) + + # ── Orchestrate ─────────────────────────────────────────────── + with gr.Tab("Orchestrate"): + gr.Markdown( + "Generate a full figure package from a paper path or resume an existing " + "orchestration run." + ) + or_paper = gr.File( + label="Paper file (required for new orchestration)", + file_types=[".pdf", ".txt", ".md"], + ) + or_resume = gr.Textbox( + label="Resume orchestration (ID or path)", + placeholder="Optional: orchestrate_... or /path/to/orchestrate_dir", + ) + with gr.Row(): + or_data_dir = gr.Textbox( + label="Data directory (optional, new run only)", + placeholder="Folder with CSV/JSON for plot extraction", + ) + or_pdf_pages = gr.Textbox( + label="PDF pages (optional, new run only)", + placeholder="e.g. 1-6,9", + ) + with gr.Row(): + or_max_method = gr.Number(label="Max methodology figures", value=6, precision=0) + or_max_plot = gr.Number(label="Max plot figures", value=4, precision=0) + with gr.Row(): + or_retry_failed = gr.Checkbox(label="Retry failed items", value=False) + or_dry_run = gr.Checkbox(label="Dry run (plan only)", value=False) + with gr.Row(): + or_max_retries = gr.Number(label="Max retries per item", value=0, precision=0) + or_concurrency = gr.Number(label="Concurrency", value=1, precision=0) + or_log = gr.Textbox(label="Orchestration log", lines=16) + or_summary = gr.Markdown() + or_go = gr.Button("Run orchestration", variant="primary") + + def _do_orchestrate( + od, + c, + vp, + vm, + ip, + im, + fo, + it, + au, + mx, + op, + sp, + sd, + paper_file, + resume_ref, + data_dir, + pdf_pages, + max_method, + max_plot, + retry_failed, + dry_run, + max_retries, + concurrency, + ): + _dotenv() + try: + st = _settings(od, c, vp, vm, ip, im, fo, it, au, mx, op, sp, sd) + paper_path = (_upload_path(paper_file) or "").strip() or None + resume_val = (resume_ref or "").strip() or None + log, summary = run_orchestrate( + st, + paper=paper_path, + resume_orchestrate=resume_val, + data_dir=(data_dir or "").strip() or None, + max_method_figures=max(1, int(max_method or 1)), + max_plot_figures=max(0, int(max_plot or 0)), + pdf_pages=(pdf_pages or "").strip() or None, + dry_run=bool(dry_run), + retry_failed=bool(retry_failed), + max_retries=max(0, int(max_retries or 0)), + concurrency=max(1, int(concurrency or 1)), + ) + return log, summary + except Exception as e: + return f"{type(e).__name__}: {e}", f"FAILED: {e}" + + or_go.click( + _do_orchestrate, + inputs=[ + out_dir, + cfg, + vlm_p, + vlm_m, + img_p, + img_m, + fmt, + iters, + auto_ref, + max_it, + opt_in, + save_pr, + seed_val, + or_paper, + or_resume, + or_data_dir, + or_pdf_pages, + or_max_method, + or_max_plot, + or_retry_failed, + or_dry_run, + or_max_retries, + or_concurrency, + ], + outputs=[or_log, or_summary], + ) + # ── Runs browser ────────────────────────────────────────────── with gr.Tab("Runs"): gr.Markdown("Inspect previous **run_*** and **batch_*** directories.") diff --git a/paperbanana/studio/runner.py b/paperbanana/studio/runner.py index 6088e96c..5dbb81e2 100644 --- a/paperbanana/studio/runner.py +++ b/paperbanana/studio/runner.py @@ -24,13 +24,24 @@ from paperbanana.core.pipeline import PaperBananaPipeline from paperbanana.core.plot_data import load_statistical_plot_payload from paperbanana.core.resume import load_resume_state +from paperbanana.core.source_loader import load_methodology_source +from paperbanana.core.sweep import ( + build_sweep_variants, + parse_csv_bools, + parse_csv_ints, + parse_csv_values, + quality_proxy_score, + rank_sweep_results, + summarize_sweep, +) from paperbanana.core.types import ( DiagramType, GenerationInput, PipelineProgressEvent, PipelineProgressStage, ) -from paperbanana.core.utils import ensure_dir, find_prompt_dir +from paperbanana.core.utils import ensure_dir, find_prompt_dir, generate_run_id, save_json +from paperbanana.core.workflow_runner import run_orchestration_package from paperbanana.evaluation.judge import VLMJudge from paperbanana.providers.registry import ProviderRegistry @@ -685,6 +696,253 @@ async def _run_one(idx: int, item: dict[str, Any]) -> None: return "\n".join(lines), str(batch_dir.resolve()) +def run_sweep( + settings: Settings, + *, + input_path: str, + caption: str, + pdf_pages: Optional[str] = None, + vlm_providers: str = "", + vlm_models: str = "", + image_providers: str = "", + image_models: str = "", + iterations: str = "", + optimize_modes: str = "", + auto_modes: str = "", + max_variants: Optional[int] = None, + dry_run: bool = False, + verbose_logging: bool = False, +) -> tuple[str, str, str]: + """Run sweep using core sweep utilities. Returns (log, sweep_dir, report_path).""" + configure_logging(verbose=verbose_logging) + lines: list[str] = ["Starting parameter sweep..."] + input_file = Path(input_path) + if not input_file.is_file(): + msg = f"Input file not found: {input_path}" + lines.append(msg) + return "\n".join(lines), "", "" + if not caption.strip(): + msg = "Caption is required." + lines.append(msg) + return "\n".join(lines), "", "" + if max_variants is not None and max_variants < 1: + msg = "max_variants must be >= 1" + lines.append(msg) + return "\n".join(lines), "", "" + + try: + variants = build_sweep_variants( + vlm_providers=parse_csv_values(vlm_providers), + vlm_models=parse_csv_values(vlm_models), + image_providers=parse_csv_values(image_providers), + image_models=parse_csv_values(image_models), + refinement_iterations=parse_csv_ints(iterations, field_name="iterations"), + optimize_inputs=parse_csv_bools(optimize_modes, field_name="optimize_modes"), + auto_refine=parse_csv_bools(auto_modes, field_name="auto_modes"), + max_variants=max_variants, + ) + except ValueError as e: + lines.append(str(e)) + return "\n".join(lines), "", "" + if not variants: + lines.append("Sweep generated zero variants.") + return "\n".join(lines), "", "" + + try: + source_context = load_methodology_source(input_file, pdf_pages=pdf_pages) + except Exception as e: + lines.append(f"{type(e).__name__}: {e}") + return "\n".join(lines), "", "" + + sweep_id = f"sweep_{generate_run_id()}" + sweep_dir = ensure_dir(Path(settings.output_dir) / sweep_id) + report_path = sweep_dir / "sweep_report.json" + lines.append(f"Sweep ID: {sweep_id}") + lines.append(f"Variants: {len(variants)}") + lines.append(f"Output: {sweep_dir}") + + if dry_run: + preview = [variant.as_dict() for variant in variants[: min(10, len(variants))]] + report = { + "sweep_id": sweep_id, + "status": "dry_run", + "input": str(input_file.resolve()), + "caption": caption, + "total_variants": len(variants), + "preview": preview, + } + save_json(report, report_path) + lines.append("Dry run complete.") + lines.append(f"Report written: {report_path}") + return "\n".join(lines), str(sweep_dir), str(report_path) + + all_results: list[dict[str, Any]] = [] + total_start = time.perf_counter() + gen_input = GenerationInput( + source_context=source_context, + communicative_intent=caption.strip(), + diagram_type=DiagramType.METHODOLOGY, + ) + + for idx, variant in enumerate(variants, start=1): + lines.append(f"Variant {idx}/{len(variants)} — {variant.variant_id}") + variant_dir = ensure_dir(sweep_dir / variant.variant_id) + overrides: dict[str, Any] = { + "output_dir": str(variant_dir), + "output_format": settings.output_format, + "vlm_provider": variant.vlm_provider, + "image_provider": variant.image_provider, + "refinement_iterations": variant.refinement_iterations, + "optimize_inputs": variant.optimize_inputs, + "auto_refine": variant.auto_refine, + } + if variant.vlm_model: + overrides["vlm_model"] = variant.vlm_model + if variant.image_model: + overrides["image_model"] = variant.image_model + variant_settings = settings.model_copy(update=overrides) + try: + variant_start = time.perf_counter() + result = asyncio.run(PaperBananaPipeline(settings=variant_settings).generate(gen_input)) + variant_seconds = time.perf_counter() - variant_start + final_critique = result.iterations[-1].critique if result.iterations else None + suggestion_count = len(final_critique.critic_suggestions) if final_critique else 0 + score = quality_proxy_score(suggestion_count) + all_results.append( + { + "status": "success", + **variant.as_dict(), + "run_id": result.metadata.get("run_id"), + "output_path": result.image_path, + "iterations_used": len(result.iterations), + "critic_suggestions": suggestion_count, + "quality_proxy_score": round(score, 2), + "total_seconds": round(variant_seconds, 2), + } + ) + lines.append(f" ok: score={score:.1f}, {variant_seconds:.1f}s") + except Exception as e: + all_results.append( + { + "status": "failed", + **variant.as_dict(), + "error": str(e), + } + ) + lines.append(f" failed: {e}") + + successful_results = [item for item in all_results if item["status"] == "success"] + ranked_results = rank_sweep_results(successful_results) + summary = summarize_sweep(all_results) + report = { + "sweep_id": sweep_id, + "status": "completed", + "input": str(input_file.resolve()), + "caption": caption, + "total_seconds": round(time.perf_counter() - total_start, 2), + "summary": summary, + "results": all_results, + "ranked_results": ranked_results, + "quality_proxy_note": ( + "quality_proxy_score = max(0, 100 - 12.5 * N) where N is critic suggestion " + "count on the final iteration" + ), + } + save_json(report, report_path) + lines.append("") + lines.append(f"Completed: {summary.get('completed', 0)}") + lines.append(f"Failed: {summary.get('failed', 0)}") + lines.append(f"Best variant: {summary.get('best_variant')}") + lines.append(f"Report written: {report_path}") + return "\n".join(lines), str(sweep_dir), str(report_path) + + +def run_orchestrate( + settings: Settings, + *, + paper: Optional[str], + resume_orchestrate: Optional[str], + data_dir: Optional[str], + max_method_figures: int, + max_plot_figures: int, + pdf_pages: Optional[str], + dry_run: bool, + retry_failed: bool, + max_retries: int, + concurrency: int, +) -> tuple[str, str]: + """Run orchestration package workflow. Returns (log, output summary markdown).""" + lines: list[str] = [] + + def _progress(msg: str) -> None: + lines.append(msg) + + try: + result = run_orchestration_package( + paper=paper, + resume_orchestrate=resume_orchestrate, + output_dir=Path(settings.output_dir), + data_dir=data_dir, + max_method_figures=max_method_figures, + max_plot_figures=max_plot_figures, + pdf_pages=pdf_pages, + dry_run=dry_run, + config=None, + vlm_provider=settings.vlm_provider, + vlm_model=settings.vlm_model, + image_provider=settings.image_provider, + image_model=settings.image_model, + iterations=settings.refinement_iterations, + auto=settings.auto_refine, + max_iterations=settings.max_iterations, + optimize=settings.optimize_inputs, + format=settings.output_format, + save_prompts=settings.save_prompts, + venue=None, + retry_failed=retry_failed, + max_retries=max(0, int(max_retries)), + concurrency=max(1, int(concurrency)), + progress_callback=_progress, + ) + except Exception as e: + lines.append(f"{type(e).__name__}: {e}") + lines.append(traceback.format_exc()) + return "\n".join(lines), f"FAILED: {e}" + + lines.insert(0, f"Orchestration ID: {result.get('orchestration_id', '?')}") + lines.insert(1, f"Output: {result.get('orchestrate_dir', '')}") + if result.get("dry_run"): + summary = ( + "## Orchestration Dry Run\n\n" + f"- Orchestration dir: `{result.get('orchestrate_dir', '')}`\n" + f"- Plan: `{result.get('orchestration_plan_path', '')}`\n" + f"- Planned methodology figures: {result.get('methodology_items_planned', 0)}\n" + f"- Planned plot figures: {result.get('plot_items_planned', 0)}\n" + ) + return "\n".join(lines), summary + + summary_lines = [ + "## Orchestration Complete", + "", + f"- Orchestration dir: `{result.get('orchestrate_dir', '')}`", + f"- Plan: `{result.get('orchestration_plan_path', '')}`", + f"- Figure package: `{result.get('figure_package_path', '')}`", + f"- TeX: `{result.get('figures_tex_path', '')}`", + f"- Captions: `{result.get('captions_md_path', '')}`", + f"- Generated: {result.get('generated_count', 0)}", + f"- Failed: {result.get('failed_count', 0)}", + ] + failures = result.get("failures") or [] + if failures: + summary_lines.append("") + summary_lines.append("### Failures") + for item in failures[:10]: + fid = item.get("id") or item.get("item_id") or "unknown" + ferr = item.get("error") or "unknown error" + summary_lines.append(f"- `{fid}`: {ferr}") + return "\n".join(lines), "\n".join(summary_lines) + + def _sanitize_output_filename(name: str) -> str: """Strip directory components and reject traversal attempts.""" cleaned = (name or "").strip() or "composite.png"