Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ OPENAI_API_KEY=
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_VLM_MODEL=gpt-5.2
OPENAI_IMAGE_MODEL=gpt-image-1.5
# For gpt-image-2, use 4k/high for final assets or 2k/auto for faster iteration.
# OUTPUT_RESOLUTION=4k
# IMAGE_QUALITY=high
# For Azure OpenAI / Foundry, point to your endpoint:
# OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/v1

Expand Down
1 change: 1 addition & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ vlm:
image:
provider: google_imagen
model: gemini-3-pro-image-preview
quality: auto

# Pipeline settings
pipeline:
Expand Down
59 changes: 47 additions & 12 deletions paperbanana/agents/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ def __init__(
prompt_dir: str = "prompts",
output_dir: str = "outputs",
prompt_recorder=None,
output_resolution: str = "2k",
image_quality: str = "auto",
):
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
self.image_gen = image_gen
self.output_dir = Path(output_dir)
self._last_vector_paths: dict[str, str] = {}
self.output_resolution = output_resolution
self.image_quality = image_quality

@property
def agent_name(self) -> str:
Expand Down Expand Up @@ -104,14 +108,18 @@ async def _generate_diagram(
logger.info("Generating diagram image", iteration=iteration)

# Determine dimensions from aspect ratio or use defaults
w, h = self._ratio_to_dimensions(aspect_ratio) if aspect_ratio else (1792, 1024)
w, h = self._ratio_to_dimensions(
aspect_ratio or "16:9",
output_resolution=self.output_resolution,
)

image = await self.image_gen.generate(
prompt=prompt,
width=w,
height=h,
seed=seed,
aspect_ratio=aspect_ratio,
quality=self.image_quality,
)

if output_path is None:
Expand All @@ -122,19 +130,46 @@ async def _generate_diagram(
return output_path

@staticmethod
def _ratio_to_dimensions(ratio: str, output_resolution: str = "2k") -> tuple[int, int]:
    """Convert an aspect ratio string to pixel dimensions for a resolution tier.

    The longer edge comes from the resolution tier ("1k" -> 1536, "2k" -> 2048,
    "4k" -> 3840; a "1:1" image at "1k" uses 1024). The shorter edge follows
    from the aspect ratio, and both edges are snapped to multiples of 16.
    When the area exceeds 8,294,400 pixels (the area of 3840x2160) the size is
    scaled down uniformly and then trimmed in 16-pixel steps until it fits.

    Args:
        ratio: Aspect ratio such as "16:9". Unrecognised values fall back to 16:9.
        output_resolution: "1k", "2k" or "4k", compared case-insensitively.
            Unrecognised values fall back to the 2048-pixel long edge.

    Returns:
        ``(width, height)`` in pixels.
    """
    ratios = {
        "21:9": (21, 9),
        "16:9": (16, 9),
        "4:3": (4, 3),
        "3:2": (3, 2),
        "1:1": (1, 1),
        "2:3": (2, 3),
        "3:4": (3, 4),
        "9:16": (9, 16),
    }
    long_edges = {"1k": 1536, "2k": 2048, "4k": 3840}
    step = 16
    max_pixels = 8_294_400

    def snap(value: float) -> int:
        """Round to the nearest multiple of ``step``, never below one step."""
        return max(step, round(value / step) * step)

    rw, rh = ratios.get(ratio, (16, 9))
    resolution = str(output_resolution).lower()
    long_edge = long_edges.get(resolution, 2048)
    # Squares at the lowest tier use a 1024 edge instead of the tier's 1536.
    if ratio == "1:1" and resolution == "1k":
        long_edge = 1024

    # The long edge goes to whichever side the ratio makes larger.
    if rw >= rh:
        width, height = long_edge, round(long_edge * rh / rw)
    else:
        width, height = round(long_edge * rw / rh), long_edge

    width, height = snap(width), snap(height)

    if width * height > max_pixels:
        # Shrink both edges by the same factor so the aspect ratio is kept.
        scale = (max_pixels / (width * height)) ** 0.5
        width, height = snap(width * scale), snap(height * scale)
        # Snapping can round back above the cap; trim the longer edge until it fits.
        while width * height > max_pixels:
            if width >= height:
                width -= step
            else:
                height -= step

    return width, height

async def _generate_plot(
self,
Expand Down
27 changes: 26 additions & 1 deletion paperbanana/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic_settings import BaseSettings

OutputFormat = Literal["png", "jpeg", "webp"]
ImageQuality = Literal["low", "medium", "high", "auto"]
ExemplarRetrievalMode = Literal["external_only", "external_then_rerank"]
Venue = Literal["neurips", "icml", "acl", "ieee", "custom"]
VectorExportMode = Literal["none", "svg", "pdf", "both"]
Expand Down Expand Up @@ -70,7 +71,8 @@ class Settings(BaseSettings):
auto_refine: bool = False
max_iterations: int = 30
optimize_inputs: bool = False
output_resolution: str = "2k"
output_resolution: str = Field(default="2k", alias="OUTPUT_RESOLUTION")
image_quality: ImageQuality = Field(default="auto", alias="IMAGE_QUALITY")
seed: Optional[int] = None
exemplar_retrieval_enabled: bool = False
exemplar_retrieval_endpoint: Optional[str] = None
Expand Down Expand Up @@ -178,6 +180,28 @@ def validate_output_format(cls, v: Any) -> str:
raise ValueError(f"output_format must be png, jpeg, or webp. Got: {v}")
return v

@field_validator("output_resolution", mode="before")
@classmethod
def validate_output_resolution(cls, v: Any) -> str:
    """Normalise ``output_resolution`` to lower case and require 1k, 2k, or 4k.

    ``None`` is replaced by "2k". Any other value is stringified and
    lower-cased before the membership check.

    Raises:
        ValueError: If the normalised value is not 1k, 2k, or 4k.
    """
    if v is None:
        return "2k"
    normalized = str(v).lower()
    if normalized in ("1k", "2k", "4k"):
        return normalized
    raise ValueError(f"output_resolution must be 1k, 2k, or 4k. Got: {normalized}")

@field_validator("image_quality", mode="before")
@classmethod
def validate_image_quality(cls, v: Any) -> str:
    """Normalise ``image_quality`` to lower case and require a known level.

    ``None`` is replaced by "auto". Any other value is stringified and
    lower-cased before the membership check.

    Raises:
        ValueError: If the normalised value is not low, medium, high, or auto.
    """
    if v is None:
        return "auto"
    normalized = str(v).lower()
    if normalized in ("low", "medium", "high", "auto"):
        return normalized
    raise ValueError(f"image_quality must be low, medium, high, or auto. Got: {normalized}")

@field_validator("exemplar_retrieval_top_k")
@classmethod
def validate_exemplar_retrieval_top_k(cls, v: int) -> int:
Expand Down Expand Up @@ -254,6 +278,7 @@ def _flatten_yaml(config: dict, prefix: str = "") -> dict:
"vlm.model": "vlm_model",
"image.provider": "image_provider",
"image.model": "image_model",
"image.quality": "image_quality",
"pipeline.num_retrieval_examples": "num_retrieval_examples",
"pipeline.refinement_iterations": "refinement_iterations",
"pipeline.auto_refine": "auto_refine",
Expand Down
2 changes: 2 additions & 0 deletions paperbanana/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def __init__(
prompt_dir=prompt_dir,
output_dir=str(self._run_dir),
prompt_recorder=self._prompt_recorder,
output_resolution=self.settings.output_resolution,
image_quality=self.settings.image_quality,
)
self.critic = CriticAgent(
self._vlm, prompt_dir=prompt_dir, prompt_recorder=self._prompt_recorder
Expand Down
5 changes: 4 additions & 1 deletion paperbanana/core/pricing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Pricing tables for VLM and image generation providers.

Prices are in USD. VLM prices are per 1K tokens. Image prices are per image.
Last updated: 2026-03-18.
Last updated: 2026-05-01.
"""

from __future__ import annotations
Expand All @@ -23,6 +23,7 @@
("gemini", "gemini-2.5-pro"): {"input_per_1k": 0.00125, "output_per_1k": 0.01},
("gemini", "gemini-3-pro"): {"input_per_1k": 0.00125, "output_per_1k": 0.005},
# OpenAI
("openai", "gpt-5.5"): {"input_per_1k": 0.005, "output_per_1k": 0.03},
("openai", "gpt-5.2"): {"input_per_1k": 0.0025, "output_per_1k": 0.01},
("openai", "gpt-5.1"): {"input_per_1k": 0.002, "output_per_1k": 0.008},
("openai", "gpt-4o"): {"input_per_1k": 0.0025, "output_per_1k": 0.01},
Expand All @@ -48,7 +49,9 @@
# Google Imagen — free tier
("google_imagen", "gemini-3-pro-image-preview"): 0.0,
# OpenAI
# gpt-image-2 (listed below) is token-priced; its per-image figure is a high-quality square-image estimate.
("openai_imagen", "gpt-image-1.5"): 0.02,
("openai_imagen", "gpt-image-2"): 0.211,
("openai_imagen", "gpt-image-1"): 0.04,
("openai_imagen", "dall-e-3"): 0.04,
# Bedrock Nova Canvas
Expand Down
2 changes: 2 additions & 0 deletions paperbanana/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ async def generate(
height: int = 1024,
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
quality: Optional[str] = None,
) -> Image.Image:
"""Generate an image from a text prompt.

Expand All @@ -113,6 +114,7 @@ async def generate(
seed: Random seed for reproducibility.
aspect_ratio: Target aspect ratio (1:1, 2:3, 3:2, 3:4, 4:3, 9:16, 16:9, 21:9).
Takes precedence over width/height for providers that support it.
quality: Optional provider-specific rendering quality.

Returns:
Generated PIL Image.
Expand Down
1 change: 1 addition & 0 deletions paperbanana/providers/image_gen/bedrock_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def generate(
height: int = 1024,
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
quality: Optional[str] = None,
) -> Image.Image:
client = self._get_client()

Expand Down
1 change: 1 addition & 0 deletions paperbanana/providers/image_gen/google_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ async def generate(
height: int = 1024,
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
quality: Optional[str] = None,
) -> Image.Image:
from google.genai import types

Expand Down
32 changes: 25 additions & 7 deletions paperbanana/providers/image_gen/openai_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
logger = structlog.get_logger()


def _is_gpt_image_2(model: str) -> bool:
    """Return True when ``model`` is exactly "gpt-image-2", ignoring case."""
    normalized = model.lower()
    return normalized == "gpt-image-2"


class OpenAIImageGen(ImageGenProvider):
"""Image generation using the OpenAI Python SDK (async).

Expand Down Expand Up @@ -62,11 +66,15 @@ def is_available(self) -> bool:

@property
def supported_ratios(self) -> list[str]:
    """Aspect ratios offered for the configured model.

    gpt-image-2 gets the full eight-ratio list; earlier GPT Image models
    only have three native sizes and therefore three ratios.
    """
    extended = ["1:1", "2:3", "3:2", "3:4", "4:3", "9:16", "16:9", "21:9"]
    legacy = ["1:1", "3:2", "2:3"]
    return extended if _is_gpt_image_2(self._model) else legacy

def _size_string(self, width: int, height: int) -> str:
"""Map pixel dimensions to an OpenAI-supported size string."""
if _is_gpt_image_2(self._model):
return f"{width}x{height}"
ratio = width / height
if ratio > 1.2:
return "1536x1024"
Expand Down Expand Up @@ -96,19 +104,29 @@ async def generate(
height: int = 1024,
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
quality: Optional[str] = None,
) -> Image.Image:
client = self._get_client()

full_prompt = prompt
if negative_prompt:
full_prompt += f"\n\nAvoid: {negative_prompt}"

result = await client.images.generate(
model=self._model,
prompt=full_prompt,
n=1,
size=self._RATIO_TO_SIZE.get(aspect_ratio, self._size_string(width, height)),
)
if _is_gpt_image_2(self._model):
size = self._size_string(width, height)
else:
size = self._RATIO_TO_SIZE.get(aspect_ratio, self._size_string(width, height))

kwargs = {
"model": self._model,
"prompt": full_prompt,
"n": 1,
"size": size,
}
if quality:
kwargs["quality"] = quality

result = await client.images.generate(**kwargs)

b64_data = result.data[0].b64_json
image_bytes = base64.b64decode(b64_data)
Expand Down
1 change: 1 addition & 0 deletions paperbanana/providers/image_gen/openrouter_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def generate(
height: int = 1024,
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
quality: Optional[str] = None,
) -> Image.Image:
client = self._get_client()

Expand Down
8 changes: 7 additions & 1 deletion paperbanana/providers/vlm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
logger = structlog.get_logger()


def _uses_fixed_temperature(model: str) -> bool:
    """Return True for models that only accept the API default temperature.

    The check is a case-insensitive prefix match on "gpt-5".
    """
    fixed_prefix = "gpt-5"
    normalized = model.lower()
    return normalized.startswith(fixed_prefix)


class OpenAIVLM(VLMProvider):
"""VLM provider using the OpenAI Python SDK (async).

Expand Down Expand Up @@ -99,8 +104,9 @@ async def generate(
kwargs = {
"model": self._model,
"messages": messages,
"temperature": temperature,
}
if not _uses_fixed_temperature(self._model):
kwargs["temperature"] = temperature

if response_format == "json" and self._json_mode:
kwargs["response_format"] = {"type": "json_object"}
Expand Down
8 changes: 8 additions & 0 deletions tests/test_agents/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,11 @@ def test_last_vector_paths_reset_on_each_run_call(tmp_path):
output_path2 = str(tmp_path / "plot2.png")
agent._execute_plot_code(_SIMPLE_PLOT_CODE, output_path2)
assert agent._last_vector_paths == {}


def test_ratio_to_dimensions_supports_high_resolution_landscape():
    """A 16:9 request at the 4k tier yields 3840x2160."""
    dimensions = VisualizerAgent._ratio_to_dimensions("16:9", output_resolution="4k")
    assert dimensions == (3840, 2160)


def test_ratio_to_dimensions_supports_2k_landscape():
    """A 16:9 request at the 2k tier yields 2048x1152."""
    dimensions = VisualizerAgent._ratio_to_dimensions("16:9", output_resolution="2k")
    assert dimensions == (2048, 1152)
26 changes: 26 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,32 @@ def test_output_format_from_yaml_invalid():
Path(path).unlink(missing_ok=True)


def test_image_generation_options_from_yaml():
    """Resolution and quality written to a YAML file are loaded onto Settings."""
    config = {"pipeline": {"output_resolution": "4k"}, "image": {"quality": "high"}}
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as handle:
        yaml.safe_dump(config, handle)
        path = handle.name

    try:
        loaded = Settings.from_yaml(path)
        assert loaded.output_resolution == "4k"
        assert loaded.image_quality == "high"
    finally:
        # The file was created with delete=False, so remove it explicitly.
        Path(path).unlink(missing_ok=True)


def test_output_resolution_invalid_rejected():
    """An unsupported output_resolution raises a ValidationError naming the allowed values."""
    expected_message = "output_resolution must be 1k, 2k, or 4k"
    with pytest.raises(ValidationError, match=expected_message):
        Settings(output_resolution="8k")


def test_image_quality_invalid_rejected():
    """An unsupported image_quality raises a ValidationError naming the allowed values."""
    expected_message = "image_quality must be low, medium, high, or auto"
    with pytest.raises(ValidationError, match=expected_message):
        Settings(image_quality="ultra")


def test_exemplar_retrieval_top_k_must_be_positive():
"""exemplar_retrieval_top_k must be >= 1."""
with pytest.raises(ValidationError, match="exemplar_retrieval_top_k must be >= 1"):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_cost_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def test_prefix_match_vlm(self):
assert result is not None
assert result["input_per_1k"] > 0

def test_openai_gpt_5_5_pricing(self):
    """The gpt-5.5 entry exists and carries the expected per-1K token rates."""
    price = lookup_vlm_price("openai", "gpt-5.5")
    assert price is not None
    expected_rates = {"input_per_1k": 0.005, "output_per_1k": 0.03}
    for key, rate in expected_rates.items():
        assert price[key] == pytest.approx(rate)

def test_unknown_vlm_returns_none(self):
result = lookup_vlm_price("unknown_provider", "unknown_model")
assert result is None
Expand All @@ -38,6 +44,10 @@ def test_openai_image_pricing(self):
assert result is not None
assert result > 0

def test_openai_gpt_image_2_pricing(self):
    """The gpt-image-2 entry resolves to its per-image estimate."""
    price = lookup_image_price("openai_imagen", "gpt-image-2")
    assert price == pytest.approx(0.211)

def test_unknown_image_returns_none(self):
result = lookup_image_price("unknown", "unknown")
assert result is None
Expand Down
Loading