diff --git a/app/images/providers.py b/app/images/providers.py index 1c31479..ef122ba 100644 --- a/app/images/providers.py +++ b/app/images/providers.py @@ -9,6 +9,8 @@ import logging import os import uuid +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FutureTimeoutError from typing import Callable, List, Sequence from urllib.parse import urlparse @@ -90,51 +92,98 @@ def describe_scene(self, theme_body: str, tag_names: Sequence[str]) -> str: # ---------- Replicate image generator ---------- -DEFAULT_FLUX_MODEL = "black-forest-labs/flux-schnell" +# The canonical "black-forest-labs/flux-schnell" model is routed through a shared +# pool that has periods where predictions accept-but-never-start. The "-lora" sibling +# runs the same Flux Schnell weights on a different pool; calling it without LoRA +# inputs produces equivalent images. It does, however, OOM on num_outputs>1, so we +# fan out one prediction per candidate. +DEFAULT_FLUX_MODEL = "black-forest-labs/flux-schnell-lora" +DEFAULT_PREDICTION_TIMEOUT_SECONDS = 60.0 class ReplicateImageGenerator: - """Calls Flux Schnell on Replicate; returns temporary candidate URLs.""" + """Generates candidates via parallel Flux Schnell predictions. + + Each candidate is its own prediction with ``num_outputs=1``. This isolates failures + (a single CUDA OOM or stuck worker costs one candidate, not the whole batch) and + lets us put a wall-clock timeout around each call so the request can never hang. + """ def __init__( self, model: str = DEFAULT_FLUX_MODEL, aspect_ratio: str = "1:1", num_outputs: int = 4, + timeout_seconds: float = DEFAULT_PREDICTION_TIMEOUT_SECONDS, runner: Callable[..., list] = replicate.run, ) -> None: self._model = model self._aspect_ratio = aspect_ratio self._num_outputs = num_outputs + self._timeout_seconds = timeout_seconds self._runner = runner def generate(self, prompt: str) -> List[str]: if not os.getenv("REPLICATE_API_TOKEN"): raise ImageGenerationError("REPLICATE_API_TOKEN is not set") + pool = ThreadPoolExecutor(max_workers=self._num_outputs) + try: + futures = [ + pool.submit(self._generate_one, prompt) + for _ in range(self._num_outputs) + ] + urls: List[str] = [] + errors: List[str] = [] + for future in futures: + try: + urls.append(future.result(timeout=self._timeout_seconds)) + except FutureTimeoutError: + errors.append(f"timed out after {self._timeout_seconds:.0f}s") + except ImageGenerationError as e: + errors.append(str(e)) + finally: + # Don't block returning to the user on slow stragglers; they'll finish + # in their own time and the threadpool is GC'd. + pool.shutdown(wait=False, cancel_futures=True) + + if not urls: + logger.error( + "All %d Replicate candidates failed (prompt_len=%d): %s", + self._num_outputs, len(prompt), errors, + ) + raise ImageGenerationError( + f"All {self._num_outputs} candidates failed: {'; '.join(errors)}" + ) + if errors: + logger.warning( + "Partial Replicate failure: %d/%d candidates failed: %s", + len(errors), self._num_outputs, errors, + ) + return urls + + def _generate_one(self, prompt: str) -> str: try: output = self._runner( self._model, input={ "prompt": prompt, - "num_outputs": self._num_outputs, + "num_outputs": 1, "aspect_ratio": self._aspect_ratio, "output_format": "webp", "output_quality": 90, }, ) except replicate.exceptions.ReplicateError as e: - logger.error("Replicate API error: %s (prompt_len=%d)", e, len(prompt)) raise ImageGenerationError(f"Replicate error: {e}") from e except httpx.HTTPError as e: - logger.error("Replicate network error: %s", e) - raise ImageGenerationError(f"Network error contacting Replicate: {e}") from e + raise ImageGenerationError(f"Network error: {e}") from e - urls = [str(item.url) if hasattr(item, "url") else str(item) for item in output] - if not urls: - logger.error("Replicate returned empty output (prompt_len=%d)", len(prompt)) - raise ImageGenerationError("Replicate returned no candidates") - return urls + items = list(output) + if not items: + raise ImageGenerationError("Replicate returned empty output") + item = items[0] + return str(item.url) if hasattr(item, "url") else str(item) # ---------- R2 image store ---------- diff --git a/app/images/service.py b/app/images/service.py index 830f814..578ea82 100644 --- a/app/images/service.py +++ b/app/images/service.py @@ -13,7 +13,9 @@ "Rendered as a flat oil painting in a limited three-color palette, " "figurative and confident, contemporary museum-quality painting, " "tone balanced between gravity and play. " - "If human figures are included, represent varied skin tones, ages, and body types." + "If human figures are included, represent varied skin tones, ages, and body types. " + "When hands are visible, they rest calmly at the sides, are folded, or hold " + "a single object — never reaching, gesturing, or pointing." ) diff --git a/tests/test_images.py b/tests/test_images.py index 542980e..74225f4 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -245,18 +245,35 @@ def test_replicate_generator_missing_token(monkeypatch): generator.generate("prompt") -def test_replicate_generator_returns_urls(monkeypatch): +def test_replicate_generator_fans_out_one_call_per_candidate(monkeypatch): + """Each candidate is its own prediction with num_outputs=1 — avoids GPU OOM.""" monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") - fake_runner = MagicMock(return_value=["https://r/1.webp", "https://r/2.webp"]) - generator = ReplicateImageGenerator(runner=fake_runner) + counter = {"n": 0} + def fake_runner(model, input): + counter["n"] += 1 + return [f"https://r/{counter['n']}.webp"] + + generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=4) urls = generator.generate("a scene") - assert urls == ["https://r/1.webp", "https://r/2.webp"] + assert len(urls) == 4 + assert counter["n"] == 4 + assert set(urls) == {f"https://r/{i}.webp" for i in range(1, 5)} + + +def test_replicate_generator_sends_correct_input_shape(monkeypatch): + monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") + fake_runner = MagicMock(return_value=["https://r/1.webp"]) + generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=1) + + generator.generate("a scene") + call_input = fake_runner.call_args.kwargs["input"] assert call_input["prompt"] == "a scene" - assert call_input["num_outputs"] == 4 + assert call_input["num_outputs"] == 1 assert call_input["aspect_ratio"] == "1:1" + assert call_input["output_format"] == "webp" def test_replicate_generator_handles_fileoutput_objects(monkeypatch): @@ -264,31 +281,57 @@ def test_replicate_generator_handles_fileoutput_objects(monkeypatch): item = MagicMock() item.url = "https://r/1.webp" fake_runner = MagicMock(return_value=[item]) - generator = ReplicateImageGenerator(runner=fake_runner) + generator = ReplicateImageGenerator(runner=fake_runner, num_outputs=1) assert generator.generate("prompt") == ["https://r/1.webp"] -def test_replicate_generator_rejects_empty_output(monkeypatch): +def test_replicate_generator_returns_partial_results_when_some_calls_fail(monkeypatch): + """One bad candidate shouldn't fail the whole batch — degraded is better than dead.""" + import replicate.exceptions monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") - generator = ReplicateImageGenerator(runner=MagicMock(return_value=[])) - with pytest.raises(ImageGenerationError, match="no candidates"): - generator.generate("prompt") + counter = {"n": 0} + + def flaky_runner(model, input): + counter["n"] += 1 + if counter["n"] == 2: + raise replicate.exceptions.ReplicateError("CUDA OOM on this one") + return [f"https://r/{counter['n']}.webp"] + generator = ReplicateImageGenerator(runner=flaky_runner, num_outputs=4) + urls = generator.generate("prompt") + assert len(urls) == 3 -def test_replicate_generator_wraps_replicate_errors(monkeypatch): + +def test_replicate_generator_raises_when_all_calls_fail(monkeypatch): import replicate.exceptions monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") broken = MagicMock(side_effect=replicate.exceptions.ReplicateError("upstream 502")) - generator = ReplicateImageGenerator(runner=broken) - with pytest.raises(ImageGenerationError, match="Replicate error"): + generator = ReplicateImageGenerator(runner=broken, num_outputs=4) + with pytest.raises(ImageGenerationError, match="All 4 candidates failed"): + generator.generate("prompt") + + +def test_replicate_generator_times_out_stuck_predictions(monkeypatch): + """A wedged Replicate prediction must not hang the request forever.""" + import time + monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") + + def slow_runner(model, input): + time.sleep(5) # would hang forever in the real failure mode + return ["never seen"] + + generator = ReplicateImageGenerator( + runner=slow_runner, num_outputs=2, timeout_seconds=0.1 + ) + with pytest.raises(ImageGenerationError, match="timed out"): generator.generate("prompt") def test_replicate_generator_wraps_http_errors(monkeypatch): monkeypatch.setenv("REPLICATE_API_TOKEN", "fake") broken = MagicMock(side_effect=httpx.ConnectError("timeout")) - generator = ReplicateImageGenerator(runner=broken) + generator = ReplicateImageGenerator(runner=broken, num_outputs=1) with pytest.raises(ImageGenerationError, match="Network error"): generator.generate("prompt")