Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions app/images/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Callable, List, Sequence
from urllib.parse import urlparse

Expand Down Expand Up @@ -90,51 +92,98 @@ def describe_scene(self, theme_body: str, tag_names: Sequence[str]) -> str:

# ---------- Replicate image generator ----------

# The canonical "black-forest-labs/flux-schnell" model is routed through a shared
# pool that has periods where predictions accept-but-never-start. The "-lora" sibling
# runs the same Flux Schnell weights on a different pool; calling it without LoRA
# inputs produces equivalent images. It does, however, OOM on num_outputs>1, so we
# fan out one prediction per candidate.
DEFAULT_FLUX_MODEL = "black-forest-labs/flux-schnell-lora"
# Default for ReplicateImageGenerator's ``timeout_seconds`` argument: how long, in
# seconds, to wait for a prediction result before counting it as failed.
DEFAULT_PREDICTION_TIMEOUT_SECONDS = 60.0


class ReplicateImageGenerator:
    """Generates candidates via parallel Flux Schnell predictions.

    Each candidate is its own prediction with ``num_outputs=1``. This isolates failures
    (a single CUDA OOM or stuck worker costs one candidate, not the whole batch) and
    lets us put a wall-clock timeout around the batch so the request can never hang.

    Args:
        model: Replicate model identifier passed to ``runner``.
        aspect_ratio: Aspect ratio forwarded in the prediction input.
        num_outputs: Number of candidates, i.e. number of parallel predictions.
        timeout_seconds: Wall-clock budget for the whole fan-out. All predictions
            start together, so this is also the budget of each individual call.
        runner: Callable invoked as ``runner(model, input={...})`` that returns an
            iterable of outputs; injectable for tests.
    """

    def __init__(
        self,
        model: str = DEFAULT_FLUX_MODEL,
        aspect_ratio: str = "1:1",
        num_outputs: int = 4,
        timeout_seconds: float = DEFAULT_PREDICTION_TIMEOUT_SECONDS,
        runner: Callable[..., list] = replicate.run,
    ) -> None:
        self._model = model
        self._aspect_ratio = aspect_ratio
        self._num_outputs = num_outputs
        self._timeout_seconds = timeout_seconds
        self._runner = runner

    def generate(self, prompt: str) -> List[str]:
        """Return candidate image URLs for ``prompt``, one per successful prediction.

        Results keep submission order. Failed or timed-out predictions are dropped
        and logged; the call only fails when no prediction succeeded.

        Raises:
            ImageGenerationError: If ``REPLICATE_API_TOKEN`` is unset, or if every
                prediction failed or timed out.
        """
        if not os.getenv("REPLICATE_API_TOKEN"):
            raise ImageGenerationError("REPLICATE_API_TOKEN is not set")

        urls: List[str] = []
        errors: List[str] = []
        pool = ThreadPoolExecutor(max_workers=self._num_outputs)
        try:
            # One shared deadline for the batch. Waiting on each future with a fresh
            # full timeout would let the worst case grow to num_outputs * timeout.
            deadline = time.monotonic() + self._timeout_seconds
            futures = [
                pool.submit(self._generate_one, prompt)
                for _ in range(self._num_outputs)
            ]
            for future in futures:
                # A zero timeout still returns an already-finished result.
                remaining = max(0.0, deadline - time.monotonic())
                try:
                    urls.append(future.result(timeout=remaining))
                except FutureTimeoutError:
                    errors.append(f"timed out after {self._timeout_seconds:.0f}s")
                except ImageGenerationError as e:
                    errors.append(str(e))
        finally:
            # Don't block returning to the user on slow stragglers. Queued work is
            # cancelled; calls that are already running cannot be interrupted and
            # finish in the background.
            pool.shutdown(wait=False, cancel_futures=True)

        if not urls:
            logger.error(
                "All %d Replicate candidates failed (prompt_len=%d): %s",
                self._num_outputs, len(prompt), errors,
            )
            raise ImageGenerationError(
                f"All {self._num_outputs} candidates failed: {'; '.join(errors)}"
            )
        if errors:
            logger.warning(
                "Partial Replicate failure: %d/%d candidates failed: %s",
                len(errors), self._num_outputs, errors,
            )
        return urls

    def _generate_one(self, prompt: str) -> str:
        """Run a single ``num_outputs=1`` prediction and return its URL.

        Raises:
            ImageGenerationError: On a Replicate API error, an HTTP/network error,
                or an empty output.
        """
        try:
            output = self._runner(
                self._model,
                input={
                    "prompt": prompt,
                    "num_outputs": 1,
                    "aspect_ratio": self._aspect_ratio,
                    "output_format": "webp",
                    "output_quality": 90,
                },
            )
        except replicate.exceptions.ReplicateError as e:
            logger.error("Replicate API error: %s (prompt_len=%d)", e, len(prompt))
            raise ImageGenerationError(f"Replicate error: {e}") from e
        except httpx.HTTPError as e:
            logger.error("Replicate network error: %s", e)
            raise ImageGenerationError(f"Network error: {e}") from e

        items = list(output)
        if not items:
            raise ImageGenerationError("Replicate returned empty output")
        item = items[0]
        # Outputs may be file-like objects exposing ``.url`` or plain URL strings.
        return str(item.url) if hasattr(item, "url") else str(item)


# ---------- R2 image store ----------
Expand Down
4 changes: 3 additions & 1 deletion app/images/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
"Rendered as a flat oil painting in a limited three-color palette, "
"figurative and confident, contemporary museum-quality painting, "
"tone balanced between gravity and play. "
"If human figures are included, represent varied skin tones, ages, and body types."
"If human figures are included, represent varied skin tones, ages, and body types. "
"When hands are visible, they rest calmly at the sides, are folded, or hold "
"a single object — never reaching, gesturing, or pointing."
)


Expand Down
71 changes: 57 additions & 14 deletions tests/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,50 +245,93 @@ def test_replicate_generator_missing_token(monkeypatch):
generator.generate("prompt")


def test_replicate_generator_fans_out_one_call_per_candidate(monkeypatch):
    """Each candidate is its own prediction with num_outputs=1 — avoids GPU OOM."""
    import threading

    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    # The runner is called from worker threads, so numbering the calls must be
    # atomic; an unguarded "increment then read" can hand out duplicate numbers.
    lock = threading.Lock()
    calls = []

    def fake_runner(model, input):
        with lock:
            calls.append(input)
            call_number = len(calls)
        return [f"https://r/{call_number}.webp"]

    generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=4)
    urls = generator.generate("a scene")

    assert len(urls) == 4
    assert len(calls) == 4
    assert all(call["num_outputs"] == 1 for call in calls)
    assert set(urls) == {f"https://r/{i}.webp" for i in range(1, 5)}


def test_replicate_generator_sends_correct_input_shape(monkeypatch):
    """The prediction input carries the prompt and always asks for one output."""
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    fake_runner = MagicMock(return_value=["https://r/1.webp"])
    generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=1)

    generator.generate("a scene")

    call_input = fake_runner.call_args.kwargs["input"]
    assert call_input["prompt"] == "a scene"
    assert call_input["num_outputs"] == 1
    assert call_input["aspect_ratio"] == "1:1"
    assert call_input["output_format"] == "webp"


def test_replicate_generator_handles_fileoutput_objects(monkeypatch):
    """Outputs exposing a ``.url`` attribute are converted to that URL string."""
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    item = MagicMock()
    item.url = "https://r/1.webp"
    fake_runner = MagicMock(return_value=[item])
    generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=1)

    assert generator.generate("prompt") == ["https://r/1.webp"]


def test_replicate_generator_rejects_empty_output(monkeypatch):
    """A prediction that yields no items counts as a failed candidate."""
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    generator = ReplicateImageGenerator(
        runner=MagicMock(return_value=[]), num_outputs=1
    )
    with pytest.raises(ImageGenerationError, match="empty output"):
        generator.generate("prompt")


def test_replicate_generator_returns_partial_results_when_some_calls_fail(monkeypatch):
    """One bad candidate shouldn't fail the whole batch — degraded is better than dead."""
    import threading

    import replicate.exceptions
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    # The runner is called from worker threads; take the call number under a lock
    # so exactly one call sees "2" and fails.
    lock = threading.Lock()
    counter = {"n": 0}

    def flaky_runner(model, input):
        with lock:
            counter["n"] += 1
            call_number = counter["n"]
        if call_number == 2:
            raise replicate.exceptions.ReplicateError("CUDA OOM on this one")
        return [f"https://r/{call_number}.webp"]

    generator = ReplicateImageGenerator(runner=flaky_runner, num_outputs=4)
    urls = generator.generate("prompt")
    assert len(urls) == 3

def test_replicate_generator_raises_when_all_calls_fail(monkeypatch):
    """If every prediction errors, the batch fails with an aggregated error."""
    import replicate.exceptions
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    broken = MagicMock(side_effect=replicate.exceptions.ReplicateError("upstream 502"))
    generator = ReplicateImageGenerator(runner=broken, num_outputs=4)
    with pytest.raises(ImageGenerationError, match="All 4 candidates failed"):
        generator.generate("prompt")


def test_replicate_generator_times_out_stuck_predictions(monkeypatch):
    """A wedged Replicate prediction must not hang the request forever."""
    import threading
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    # An Event rather than a bare sleep: the workers can be released as soon as the
    # assertion is done, instead of lingering and delaying interpreter shutdown.
    release = threading.Event()

    def slow_runner(model, input):
        release.wait(5)  # would hang forever in the real failure mode
        return ["never seen"]

    generator = ReplicateImageGenerator(
        runner=slow_runner, num_outputs=2, timeout_seconds=0.1
    )
    try:
        with pytest.raises(ImageGenerationError, match="timed out"):
            generator.generate("prompt")
    finally:
        release.set()


def test_replicate_generator_wraps_http_errors(monkeypatch):
    """httpx transport failures surface as ImageGenerationError."""
    monkeypatch.setenv("REPLICATE_API_TOKEN", "fake")
    broken = MagicMock(side_effect=httpx.ConnectError("timeout"))
    generator = ReplicateImageGenerator(runner=broken, num_outputs=1)
    with pytest.raises(ImageGenerationError, match="Network error"):
        generator.generate("prompt")

Expand Down
Loading