Skip to content
164 changes: 164 additions & 0 deletions examples/tikz_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""End-to-end TikZ export example.

Demonstrates three usage patterns:

1. Generate a methodology diagram and export TikZ in a single pipeline run.
2. Export TikZ from an already-generated image (post-hoc, via the `tikz` CLI
command or the TikZExporterAgent API directly).
3. Export PGFPlots from a generated statistical plot.

Each pattern is a standalone coroutine introduced by a ``# ---`` banner comment; running this file as a script executes pattern 1 only.
"""

from __future__ import annotations

import asyncio
from pathlib import Path

# ---------------------------------------------------------------------------
# Pattern 1: Generate + export TikZ in one pipeline run
# ---------------------------------------------------------------------------
# CLI equivalent:
# paperbanana generate \
# --input examples/sample_inputs/transformer_method.txt \
# --caption "Overview of our Transformer encoder architecture" \
# --export-tikz
#
# The pipeline saves:
# outputs/<run_id>/final_output.png ← raster image
# outputs/<run_id>/final_output.tex ← TikZ source (new)


async def generate_with_tikz_export(
    input_path: str = "examples/sample_inputs/transformer_method.txt",
    caption: str = "Overview of our Transformer encoder architecture",
) -> None:
    """Generate a methodology diagram with TikZ export enabled and print the paths.

    Args:
        input_path: UTF-8 text file whose contents become the source context.
        caption: Communicative intent handed to the pipeline.
    """
    from dotenv import load_dotenv

    # The .env file is loaded before the project modules are imported
    # (presumably so environment-driven configuration is visible to them).
    load_dotenv()

    from paperbanana.core.config import Settings
    from paperbanana.core.pipeline import PaperBananaPipeline
    from paperbanana.core.types import DiagramType, GenerationInput

    tikz_settings = Settings(export_tikz=True)

    methodology_text = Path(input_path).read_text(encoding="utf-8")
    request = GenerationInput(
        source_context=methodology_text,
        communicative_intent=caption,
        diagram_type=DiagramType.METHODOLOGY,
    )

    runner = PaperBananaPipeline(settings=tikz_settings)
    outcome = await runner.generate(request)

    # Report the raster image first, then the exported TikZ source.
    print(f"Image : {outcome.image_path}")
    print(f"TikZ : {outcome.tikz_path}")


# ---------------------------------------------------------------------------
# Pattern 2: Post-hoc export from an existing image
# ---------------------------------------------------------------------------
# CLI equivalent:
# paperbanana tikz \
# --input outputs/my_run/final_output.png \
# --source-context examples/sample_inputs/transformer_method.txt \
# --caption "Overview of our Transformer encoder architecture"


async def export_existing_image(
    image_path: str,
    source_context_path: str = "",
    caption: str = "",
    output_tex: str = "",
) -> str:
    """Convert an existing generated image to TikZ source.

    The image is passed to ``TikZExporterAgent.run`` as a methodology diagram,
    and the returned source is written to disk as UTF-8.

    Args:
        image_path: Path to the PNG/JPEG image to convert.
        source_context_path: Optional path to the methodology text file. When
            empty, an empty string is passed as the source context.
        caption: Optional figure caption for context.
        output_tex: Where to write the .tex file. When empty, the image path
            with its suffix replaced by ``.tex`` is used. Missing parent
            directories are created.

    Returns:
        Path to the saved .tex file.
    """
    from dotenv import load_dotenv

    load_dotenv()

    from paperbanana.agents.tikz_exporter import TikZExporterAgent
    from paperbanana.core.config import Settings
    from paperbanana.core.types import DiagramType
    from paperbanana.providers.registry import ProviderRegistry

    settings = Settings()
    vlm = ProviderRegistry.create_vlm(settings)
    agent = TikZExporterAgent(vlm)

    context_text = ""
    if source_context_path:
        context_text = Path(source_context_path).read_text(encoding="utf-8")

    tikz_source = await agent.run(
        image_path=image_path,
        source_context=context_text,
        caption=caption,
        diagram_type=DiagramType.METHODOLOGY,
        venue=settings.venue,
    )

    tex_path = Path(output_tex) if output_tex else Path(image_path).with_suffix(".tex")
    # Create the destination directory so a fresh ``output_tex`` location does
    # not discard the already-generated source with a FileNotFoundError.
    tex_path.parent.mkdir(parents=True, exist_ok=True)
    tex_path.write_text(tikz_source, encoding="utf-8")
    print(f"TikZ source saved to: {tex_path}")
    return str(tex_path)


# ---------------------------------------------------------------------------
# Pattern 3: Generate a statistical plot and export PGFPlots
# ---------------------------------------------------------------------------
# CLI equivalent:
# paperbanana plot \
# --data examples/sample_data/benchmark_slice.csv \
# --intent "Accuracy vs. model size across four baselines" \
# --export-pgfplots


async def generate_plot_with_pgfplots(
    data_path: str = "examples/sample_data/benchmark_slice.csv",
    intent: str = "Accuracy vs. model size across four baselines",
) -> None:
    """Generate a statistical plot with PGFPlots export enabled and print the paths.

    Args:
        data_path: Data file handed to ``load_statistical_plot_payload``.
        intent: Communicative intent handed to the pipeline.
    """
    from dotenv import load_dotenv

    load_dotenv()

    from paperbanana.core.config import Settings
    from paperbanana.core.pipeline import PaperBananaPipeline
    from paperbanana.core.plot_data import load_statistical_plot_payload
    from paperbanana.core.types import DiagramType, GenerationInput

    # The loader returns a (source_context, raw_data) pair.
    payload = load_statistical_plot_payload(Path(data_path))
    plot_context, plot_rows = payload

    pgf_settings = Settings(export_pgfplots=True)

    # The raw rows are wrapped under a single "data" key for the pipeline.
    wrapped_rows = {"data": plot_rows}
    request = GenerationInput(
        source_context=plot_context,
        communicative_intent=intent,
        diagram_type=DiagramType.STATISTICAL_PLOT,
        raw_data=wrapped_rows,
    )

    runner = PaperBananaPipeline(settings=pgf_settings)
    outcome = await runner.generate(request)

    # The PGFPlots location is read from the result's ``tikz_path`` field.
    print(f"Plot : {outcome.image_path}")
    print(f"PGFPlots : {outcome.tikz_path}")


# ---------------------------------------------------------------------------
# Main — run pattern 1 as a quick smoke-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the pattern-1 coroutine with its default arguments, then drive it
    # to completion on a fresh event loop.
    smoke_test = generate_with_tikz_export()
    asyncio.run(smoke_test)
2 changes: 2 additions & 0 deletions paperbanana/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from paperbanana.agents.planner import PlannerAgent
from paperbanana.agents.retriever import RetrieverAgent
from paperbanana.agents.stylist import StylistAgent
from paperbanana.agents.tikz_exporter import TikZExporterAgent
from paperbanana.agents.visualizer import VisualizerAgent

__all__ = [
Expand All @@ -18,4 +19,5 @@
"StylistAgent",
"VisualizerAgent",
"CriticAgent",
"TikZExporterAgent",
]
134 changes: 134 additions & 0 deletions paperbanana/agents/tikz_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""TikZ Exporter Agent: Converts generated images to LaTeX/TikZ source code."""

from __future__ import annotations

import re
from typing import Optional

import structlog

from paperbanana.agents.base import BaseAgent
from paperbanana.core.types import DiagramType
from paperbanana.core.utils import load_image
from paperbanana.providers.base import VLMProvider

# Module-level structlog logger used for the exporter's log events.
logger = structlog.get_logger()

# Header comment template embedded in every exported .tex file.
# The five ``{...}`` placeholders are filled with ``str.format`` in
# ``TikZExporterAgent.run``. The backslash right after the opening quotes
# suppresses the leading newline, so the rendered header begins with the
# first ``%`` line and ends with a trailing newline.
_HEADER_TEMPLATE = """\
% Generated by PaperBanana {version}
% Source: {source_label}
% Model: {model}
% Venue: {venue}
% Diagram type: {diagram_type}
"""


class TikZExporterAgent(BaseAgent):
    """Converts a generated academic illustration to LaTeX/TikZ source.

    The agent sends the image plus textual context to the VLM. It uses the
    ``diagram`` prompt for methodology diagrams and the ``plot`` prompt for
    every other diagram type (intended for PGFPlots output). The model's
    answer is stripped of markdown fences and prefixed with a ``%`` header
    comment describing how it was produced.
    """

    def __init__(
        self,
        vlm_provider: VLMProvider,
        prompt_dir: str = "prompts",
        prompt_recorder=None,
    ):
        """Forward the provider, prompt directory and recorder to ``BaseAgent``."""
        super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)

    @property
    def agent_name(self) -> str:
        """Identifier of this agent (``"tikz_exporter"``)."""
        return "tikz_exporter"

    async def run(
        self,
        image_path: str,
        source_context: str,
        caption: str,
        diagram_type: DiagramType = DiagramType.METHODOLOGY,
        description: Optional[str] = None,
        source_label: str = "",
        model_label: str = "",
        venue: str = "neurips",
        version: str = "",
    ) -> str:
        """Convert a generated image to TikZ/PGFPlots source.

        Args:
            image_path: Path to the generated image to convert.
            source_context: Original methodology text (used as context).
            caption: Figure caption / communicative intent.
            diagram_type: METHODOLOGY selects the ``diagram`` prompt; any
                other value selects the ``plot`` prompt.
            description: Optimised description from the planner/stylist.
                Falls back to ``"(not available)"`` when missing or empty.
            source_label: Human-readable label for the header comment.
                Defaults to ``image_path`` when empty.
            model_label: Model label for the header comment. Defaults to the
                provider's ``model_name`` attribute, or ``"unknown"``.
            venue: Target venue (neurips, icml, …), recorded in the header.
            version: PaperBanana version string for the header comment.
                Defaults to ``"unknown"`` when empty.

        Returns:
            A string containing the header comment followed by the extracted
            LaTeX snippet.
        """
        image = load_image(image_path)

        prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot"
        template = self.load_prompt(prompt_type)

        desc_block = description.strip() if description else "(not available)"
        prompt = self.format_prompt(
            template,
            prompt_label="tikz_exporter",
            source_context=source_context,
            caption=caption,
            description=desc_block,
        )

        logger.info(
            "Running TikZ exporter",
            image_path=image_path,
            diagram_type=diagram_type.value,
        )

        response = await self.vlm.generate(
            prompt=prompt,
            images=[image],
            temperature=0.2,
            max_tokens=8192,
        )

        tikz_code = self._extract_code(response)
        if not tikz_code:
            # Still return the header so callers get a string, but make the
            # empty export visible in the logs.
            logger.warning("TikZ exporter returned no code", image_path=image_path)

        header = _HEADER_TEMPLATE.format(
            version=version or "unknown",
            source_label=source_label or image_path,
            model=model_label or getattr(self.vlm, "model_name", "unknown"),
            venue=venue,
            diagram_type=diagram_type.value,
        )

        full_output = header + tikz_code
        logger.info("TikZ export complete", output_length=len(full_output))
        return full_output

    @staticmethod
    def _extract_code(response: str) -> str:
        """Strip markdown code fences and return the raw LaTeX/TikZ snippet.

        Resolution order:

        1. The first complete fenced block (bare fence, or tagged ``latex``,
           ``tikz``, ``pgfplots`` or ``tex``, case-insensitive).
        2. A response containing exactly one fence marker, e.g. output cut
           off before the closing fence: the text after the fence line, or
           the text before it when nothing follows.
        3. Otherwise the whole response, trimmed.

        Args:
            response: Raw text returned by the VLM.

        Returns:
            The extracted snippet with surrounding whitespace removed.
        """
        fenced = re.search(
            r"```(?:latex|tikz|pgfplots|tex)?\s*\n(.*?)```",
            response,
            re.DOTALL | re.IGNORECASE,
        )
        if fenced:
            return fenced.group(1).strip()

        stripped = response.strip()

        # A lone fence marker would otherwise end up verbatim in the .tex
        # file and break compilation.
        if stripped.count("```") == 1:
            before, _, after = stripped.partition("```")
            # Skip the remainder of the fence line (the optional language tag).
            body = after.partition("\n")[2].strip()
            return body or before.strip() or stripped

        # Best-effort: raw LaTeX or anything else is returned trimmed.
        return stripped
Loading
Loading