Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions paperbanana/agents/tikz_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""TikZ/PGFPlots exporter agent — converts generated images to LaTeX source."""

from __future__ import annotations

import re
from typing import Optional

import structlog

from paperbanana.agents.base import BaseAgent
from paperbanana.core.types import DiagramType
from paperbanana.core.utils import load_image
from paperbanana.providers.base import VLMProvider

logger = structlog.get_logger()

_HEADER_TEMPLATE = """\
% Generated by PaperBanana {version}
% Source: {source_label}
% Model: {model}
% Venue: {venue}
% Diagram type: {diagram_type}
"""


class TikZExporterAgent(BaseAgent):
"""Converts a generated academic illustration to compilable LaTeX/TikZ source.

For methodology diagrams the agent emits a standalone TikZ picture.
For statistical plots it emits PGFPlots markup.
"""

def __init__(
self,
vlm_provider: VLMProvider,
prompt_dir: str = "prompts",
prompt_recorder=None,
):
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)

@property
def agent_name(self) -> str:
return "tikz_exporter"

async def run(
self,
image_path: str,
source_context: str,
caption: str,
diagram_type: DiagramType = DiagramType.METHODOLOGY,
description: Optional[str] = None,
source_label: str = "",
model_label: str = "",
venue: str = "neurips",
version: str = "",
) -> str:
"""Convert a generated image to TikZ/PGFPlots source.

Returns:
A string containing the complete LaTeX snippet with metadata header.
"""
image = load_image(image_path)

prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot"
template = self.load_prompt(prompt_type)

desc_block = description.strip() if description else "(not available)"
prompt = self.format_prompt(
template,
prompt_label="tikz_exporter",
source_context=source_context,
caption=caption,
description=desc_block,
)

logger.info(
"Running TikZ exporter",
image_path=image_path,
diagram_type=diagram_type.value,
)

response = await self.vlm.generate(
prompt=prompt,
images=[image],
temperature=0.2,
max_tokens=8192,
)

tikz_code = self._extract_code(response)

header = _HEADER_TEMPLATE.format(
version=version or "unknown",
source_label=source_label or image_path,
model=model_label or getattr(self.vlm, "model_name", "unknown"),
venue=venue,
diagram_type=diagram_type.value,
)

full_output = header + tikz_code
logger.info("TikZ export complete", output_length=len(full_output))
return full_output

@staticmethod
def _extract_code(response: str) -> str:
"""Strip markdown code fences and return the raw LaTeX/TikZ snippet."""
fenced = re.search(
r"```(?:latex|tikz|pgfplots|tex)?\s*\n(.*?)```",
response,
re.DOTALL | re.IGNORECASE,
)
if fenced:
return fenced.group(1).strip()

stripped = response.strip()
if stripped.startswith(("\\begin", "%", "\\tikz")):
return stripped

return stripped
160 changes: 160 additions & 0 deletions paperbanana/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ def generate(
"--pdf-pages",
help=("PDF input only: 1-based pages (e.g. '1-5', '3', '1-3,7,10-12'); default: all pages"),
),
export_tikz: bool = typer.Option(
False,
"--export-tikz",
help="Export a compilable TikZ/LaTeX source file alongside the generated image",
),
verbose: bool = typer.Option(
False, "--verbose", "-v", help="Show detailed agent progress and timing"
),
Expand Down Expand Up @@ -270,6 +275,8 @@ def generate(
overrides["venue"] = venue
if prompt_dir:
overrides["prompt_dir"] = prompt_dir
if export_tikz:
overrides["export_tikz"] = True

if config:
settings = Settings.from_yaml(config, **overrides)
Expand Down Expand Up @@ -595,6 +602,8 @@ def on_progress(event: PipelineProgressEvent) -> None:
f" · {len(result.iterations)} iterations[/dim]\n"
)
console.print(f" Output: [bold]{result.image_path}[/bold]")
if result.tikz_path:
console.print(f" TikZ: [bold]{result.tikz_path}[/bold]")
console.print(f" Run ID: [dim]{result.metadata.get('run_id', 'unknown')}[/dim]")

cost_data = result.metadata.get("cost")
Expand Down Expand Up @@ -1564,6 +1573,11 @@ def plot(
"--budget",
help="Budget cap in USD; pipeline aborts gracefully when exceeded",
),
export_pgfplots: bool = typer.Option(
False,
"--export-pgfplots",
help="Export a compilable PGFPlots/LaTeX source file alongside the generated plot",
),
):
"""Generate a statistical plot from data."""
if format not in ("png", "jpeg", "webp"):
Expand Down Expand Up @@ -1602,6 +1616,7 @@ def plot(
save_prompts=True if save_prompts is None else save_prompts,
venue=venue,
budget_usd=budget,
export_pgfplots=export_pgfplots,
)

gen_input = GenerationInput(
Expand Down Expand Up @@ -1651,6 +1666,8 @@ async def _run():

result = asyncio.run(_run())
console.print(f"\n[green]Done![/green] Plot saved to: [bold]{result.image_path}[/bold]")
if result.tikz_path:
console.print(f" PGFPlots: [bold]{result.tikz_path}[/bold]")

cost_data = result.metadata.get("cost")
if cost_data:
Expand Down Expand Up @@ -2466,5 +2483,148 @@ def studio(
)


@app.command()
def tikz(
input: str = typer.Option(
...,
"--input",
"-i",
help="Path to an existing generated image to convert to LaTeX/TikZ",
),
output: Optional[str] = typer.Option(
None,
"--output",
"-o",
help="Output .tex file path (default: same directory as input, with .tex extension)",
),
source_context: Optional[str] = typer.Option(
None,
"--source-context",
help="Path to a methodology text file for context (improves TikZ fidelity)",
),
caption: str = typer.Option(
"",
"--caption",
"-c",
help="Figure caption / communicative intent (improves TikZ fidelity)",
),
diagram_type: str = typer.Option(
"diagram",
"--diagram-type",
help="Type of illustration: diagram (TikZ) or plot (PGFPlots)",
),
vlm_provider: Optional[str] = typer.Option(
None, "--vlm-provider", help="VLM provider (gemini, openai, anthropic, ...)"
),
vlm_model: Optional[str] = typer.Option(
None, "--vlm-model", help="VLM model name override"
),
venue: Optional[str] = typer.Option(
None,
"--venue",
help="Target venue style (neurips, icml, acl, ieee, custom)",
),
config: Optional[str] = typer.Option(
None, "--config", help="Path to a YAML config file"
),
verbose: bool = typer.Option(
False, "--verbose", "-v", help="Show detailed progress"
),
):
"""Convert an existing generated image to a compilable LaTeX/TikZ source file."""
input_path = Path(input)
if not input_path.exists():
console.print(f"[red]Error: Input file not found: {input}[/red]")
raise typer.Exit(1)

if diagram_type not in ("diagram", "plot"):
console.print("[red]Error: --diagram-type must be 'diagram' or 'plot'[/red]")
raise typer.Exit(1)

if venue and venue.lower() not in ("neurips", "icml", "acl", "ieee", "custom"):
console.print(
f"[red]Error: --venue must be neurips, icml, acl, ieee, or custom. Got: {venue}[/red]"
)
raise typer.Exit(1)

configure_logging(verbose=verbose)

tex_path = Path(output) if output else input_path.with_suffix(".tex")

context_text = ""
if source_context:
sc_path = Path(source_context)
if not sc_path.exists():
console.print(
f"[red]Error: Source context file not found: {source_context}[/red]"
)
raise typer.Exit(1)
context_text = sc_path.read_text(encoding="utf-8")

overrides: dict = {}
if vlm_provider:
overrides["vlm_provider"] = vlm_provider
if vlm_model:
overrides["vlm_model"] = vlm_model
if venue:
overrides["venue"] = venue

if config:
settings = Settings.from_yaml(config, **overrides)
else:
from dotenv import load_dotenv

load_dotenv()
settings = Settings(**overrides)

from paperbanana.agents.tikz_exporter import TikZExporterAgent
from paperbanana.providers.registry import ProviderRegistry

dtype = (
DiagramType.METHODOLOGY
if diagram_type == "diagram"
else DiagramType.STATISTICAL_PLOT
)

console.print(
Panel.fit(
f"[bold]PaperBanana[/bold] — Export to LaTeX/TikZ\n\n"
f"Input: {input_path}\n"
f"Output: {tex_path}\n"
f"Type: {'TikZ (methodology)' if dtype == DiagramType.METHODOLOGY else 'PGFPlots'}\n"
f"VLM: {settings.vlm_provider} / {settings.effective_vlm_model}",
border_style="blue",
)
)

async def _run():
vlm = ProviderRegistry.create_vlm(settings)
agent = TikZExporterAgent(vlm)
return await agent.run(
image_path=str(input_path),
source_context=context_text,
caption=caption,
diagram_type=dtype,
venue=settings.venue,
)

console.print()
console.print(" [dim]●[/dim] Generating TikZ source...", end="")
try:
tikz_source = asyncio.run(_run())
except Exception as e:
console.print(f" [red]✗[/red]\n[red]Error: {e}[/red]")
raise typer.Exit(1)

console.print(" [green]✓[/green]")

tex_path.parent.mkdir(parents=True, exist_ok=True)
tex_path.write_text(tikz_source, encoding="utf-8")

console.print(
f"\n[green]Done![/green] LaTeX source saved to: [bold]{tex_path}[/bold]"
)


if __name__ == "__main__":
app()
4 changes: 4 additions & 0 deletions paperbanana/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class Settings(BaseSettings):
output_format: OutputFormat = "png"
save_iterations: bool = True
save_prompts: bool = True
export_tikz: bool = False
export_pgfplots: bool = False

# Prompt settings
prompt_dir: Optional[str] = None
Expand Down Expand Up @@ -242,6 +244,8 @@ def _flatten_yaml(config: dict, prefix: str = "") -> dict:
"output.format": "output_format",
"output.save_iterations": "save_iterations",
"output.save_prompts": "save_prompts",
"output.export_tikz": "export_tikz",
"output.export_pgfplots": "export_pgfplots",
"cost.budget": "budget_usd",
"pipeline.prompt_dir": "prompt_dir",
}
Expand Down
Loading
Loading