Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""add image_url to theme

Revision ID: 6998ce81619a
Revises: ded91ca4b249
Create Date: 2026-04-18 18:00:33.038004

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision: str = '6998ce81619a'
down_revision: Union[str, Sequence[str], None] = 'ded91ca4b249'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'theme',
sa.Column('image_url', sqlmodel.sql.sqltypes.AutoString(length=2048), nullable=True),
)


def downgrade() -> None:
"""Downgrade schema."""
op.drop_column('theme', 'image_url')
110 changes: 91 additions & 19 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,26 @@
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
from typing import List, Optional

import boto3
from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from starlette.requests import Request
from sqlmodel import Session, col, func, or_, select
from sqlmodel import Session, SQLModel, col, func, or_, select

from app.auth import create_access_token, require_admin, verify_admin_password
from app.db import get_session
from app.images import (
ImageCommitError,
ImageGenerationError,
build_theme_prompt,
commit_image_to_r2,
generate_candidate_images,
tags_for_prompt,
)
from app.models import (
Breadcrumb,
BreadcrumbBase,
Expand All @@ -32,6 +39,7 @@
ThemeUpdate,
Visibility,
)
from app.storage import R2ConfigError, put_object


@asynccontextmanager
Expand Down Expand Up @@ -243,6 +251,80 @@ def delete_theme(
session.flush()


# ---------- theme cover image endpoints ----------


class GenerateImageResponse(SQLModel):
prompt: str
candidates: List[str]


class CommitImageRequest(SQLModel):
source_url: str


@router.post("/themes/{theme_id}/generate-image", response_model=GenerateImageResponse)
def generate_theme_image(
theme_id: int,
session: Session = Depends(get_session),
_admin: None = Depends(require_admin),
):
theme = session.get(Theme, theme_id)
if not theme:
raise HTTPException(status_code=404, detail="Theme not found")

prompt = build_theme_prompt(theme, tags_for_prompt(theme))
try:
candidates = generate_candidate_images(prompt)
except ImageGenerationError as e:
logger.warning("Image generation failed for theme %d: %s", theme_id, e)
raise HTTPException(status_code=503, detail=str(e))

return GenerateImageResponse(prompt=prompt, candidates=candidates)


@router.post("/themes/{theme_id}/image", response_model=ThemePublic)
def commit_theme_image(
theme_id: int,
body: CommitImageRequest,
session: Session = Depends(get_session),
_admin: None = Depends(require_admin),
):
theme = session.get(Theme, theme_id)
if not theme:
raise HTTPException(status_code=404, detail="Theme not found")

try:
theme.image_url = commit_image_to_r2(body.source_url)
except ImageCommitError as e:
logger.warning("Image commit failed for theme %d: %s", theme_id, e)
raise HTTPException(status_code=400, detail=str(e))
except R2ConfigError as e:
logger.error("R2 misconfigured during commit for theme %d: %s", theme_id, e)
raise HTTPException(status_code=500, detail="Storage is misconfigured")

session.add(theme)
session.flush()
session.refresh(theme)
return theme


@router.delete("/themes/{theme_id}/image", response_model=ThemePublic)
def clear_theme_image(
theme_id: int,
session: Session = Depends(get_session),
_admin: None = Depends(require_admin),
):
theme = session.get(Theme, theme_id)
if not theme:
raise HTTPException(status_code=404, detail="Theme not found")
theme.image_url = None
session.add(theme)
session.flush()
session.refresh(theme)
return theme


# ---------- breadcrumb endpoints ----------


Expand Down Expand Up @@ -410,22 +492,12 @@ def upload_image(
ext = os.path.splitext(file.filename or "")[1] or ".png"
key = f"{uuid.uuid4().hex[:12]}{ext}"

s3 = boto3.client(
"s3",
endpoint_url=f"https://{os.getenv('R2_ACCOUNT_ID')}.r2.cloudflarestorage.com",
aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
region_name="auto",
)
s3.put_object(
Bucket=os.getenv("R2_BUCKET_NAME"),
Key=key,
Body=contents,
ContentType=file.content_type,
)

public_url = f"{os.getenv('R2_PUBLIC_URL', '').rstrip('/')}/{key}"
return {"url": public_url}
try:
url = put_object(key, contents, file.content_type)
except R2ConfigError as e:
logger.error("R2 misconfigured during upload: %s", e)
raise HTTPException(status_code=500, detail="Storage is misconfigured")
return {"url": url}


# ---------- app assembly ----------
Expand Down
167 changes: 167 additions & 0 deletions app/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Theme cover image generation via Replicate (Flux Schnell) + R2 storage."""

import logging
import os
import uuid
from typing import List
from urllib.parse import urlparse

import httpx
import replicate
import replicate.exceptions
from dotenv import load_dotenv

from app.models import Theme
from app.storage import put_object

load_dotenv()

logger = logging.getLogger(__name__)


FLUX_SCHNELL_MODEL = "black-forest-labs/flux-schnell"
DEFAULT_ASPECT_RATIO = "1:1"
DEFAULT_NUM_OUTPUTS = 4

# Replicate delivery hosts — allowlist for SSRF defense on commit.
ALLOWED_SOURCE_HOSTS = frozenset({"replicate.delivery", "pbxt.replicate.delivery"})

# Content-type → extension map. Unknown types are rejected rather than defaulted,
# so we don't silently store non-image bytes as .webp.
CONTENT_TYPE_EXTENSIONS = {
"image/webp": ".webp",
"image/png": ".png",
"image/jpeg": ".jpg",
}

# Magic bytes for the formats we accept.
IMAGE_MAGIC_BYTES = (
b"RIFF", # webp (followed by size + "WEBP" at offset 8)
b"\x89PNG", # png
b"\xff\xd8\xff", # jpeg
)

STYLE_SUFFIX = (
"Rendered as a flat oil painting in a limited three-color palette, "
"figurative and confident, contemporary museum-quality painting. "
"When human figures appear, depict a diversity of people — "
"including people of color, varied ages, and varied body types. "
"Tone balanced between gravity and play."
)


class ImageGenerationError(RuntimeError):
"""Upstream Replicate error — transient, tell the writer to retry."""


class ImageCommitError(RuntimeError):
"""Error downloading or validating a chosen candidate image."""


def build_theme_prompt(theme: Theme, tag_names: List[str]) -> str:
"""Translate a theme + tags into a natural-language Flux prompt.

Scene framing + theme snippet + tag mood + fixed style suffix.
"""
body = (theme.body_md or "").strip()
first_sentence = body.split(".")[0].strip()
snippet = first_sentence if 0 < len(first_sentence) <= 200 else body[:200].strip()

tags_phrase = ""
if tag_names:
readable = ", ".join(name.replace("-", " ") for name in tag_names)
tags_phrase = f" Mood draws from: {readable}."

return (
f'A single figurative scene that evokes the theme: "{snippet}".'
f"{tags_phrase} {STYLE_SUFFIX}"
)


def generate_candidate_images(
prompt: str,
num_outputs: int = DEFAULT_NUM_OUTPUTS,
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
) -> List[str]:
"""Call Flux Schnell on Replicate, return a list of temporary image URLs."""
if not os.getenv("REPLICATE_API_TOKEN"):
raise ImageGenerationError("REPLICATE_API_TOKEN is not set")

try:
output = replicate.run(
FLUX_SCHNELL_MODEL,
input={
"prompt": prompt,
"num_outputs": num_outputs,
"aspect_ratio": aspect_ratio,
"output_format": "webp",
"output_quality": 90,
},
)
except replicate.exceptions.ReplicateError as e:
logger.error("Replicate API error: %s (prompt_len=%d)", e, len(prompt))
raise ImageGenerationError(f"Replicate error: {e}") from e
except httpx.HTTPError as e:
logger.error("Replicate network error: %s (prompt_len=%d)", e, len(prompt))
raise ImageGenerationError(f"Network error contacting Replicate: {e}") from e

urls = [str(item.url) if hasattr(item, "url") else str(item) for item in output]
if not urls:
logger.error("Replicate returned empty output (prompt_len=%d)", len(prompt))
raise ImageGenerationError("Replicate returned no candidates")
return urls


def _is_allowed_source(source_url: str) -> bool:
"""Validate source_url is HTTPS and hosted on an allowlisted Replicate domain."""
parsed = urlparse(source_url)
if parsed.scheme != "https":
return False
host = parsed.hostname or ""
return host in ALLOWED_SOURCE_HOSTS or any(
host.endswith(f".{h}") for h in ALLOWED_SOURCE_HOSTS
)


def _looks_like_image(data: bytes) -> bool:
return any(data.startswith(magic) for magic in IMAGE_MAGIC_BYTES)


def commit_image_to_r2(source_url: str) -> str:
"""Download an allowlisted Replicate URL and re-upload to R2.

Rejects non-HTTPS URLs, hosts outside the Replicate allowlist, redirects,
unknown content types, and payloads that don't start with a known image
magic-byte sequence. Returns the permanent public R2 URL.
"""
if not _is_allowed_source(source_url):
raise ImageCommitError(
f"source_url host not in allowlist: {urlparse(source_url).hostname}"
)

try:
response = httpx.get(source_url, follow_redirects=False, timeout=30)
response.raise_for_status()
except httpx.HTTPError as e:
logger.error("Failed to download candidate image from %s: %s", source_url, e)
raise ImageCommitError(f"Could not download candidate: {e}") from e

data = response.content
content_type = response.headers.get("content-type", "").split(";")[0].strip()

ext = CONTENT_TYPE_EXTENSIONS.get(content_type)
if ext is None:
logger.error("Rejecting unknown content-type from %s: %s", source_url, content_type)
raise ImageCommitError(f"Unsupported content-type: {content_type!r}")

if not _looks_like_image(data):
logger.error("Payload from %s did not match image magic bytes", source_url)
raise ImageCommitError("Downloaded payload is not a recognized image")

key = f"theme-{uuid.uuid4().hex[:12]}{ext}"
return put_object(key, data, content_type)


def tags_for_prompt(theme: Theme) -> List[str]:
"""Extract tag names safely; theme.tags may be empty or not loaded."""
return [t.name for t in (theme.tags or [])]
6 changes: 6 additions & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ class ThemeBase(SQLModel, table=False):
visibility: Visibility = Field(
default=Visibility.draft, description="The theme's status (draft or published)"
)
image_url: Optional[str] = Field(
default=None,
max_length=2048,
description="Public R2 URL for the theme's generated cover image",
)
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="When this theme was created",
Expand Down Expand Up @@ -76,6 +81,7 @@ class ThemeCreate(ThemeBase, table=False):
class ThemeUpdate(SQLModel, table=False):
body_md: Optional[str] = None
visibility: Optional[Visibility] = None
image_url: Optional[str] = None
tags: Optional[List["TagCreate"]] = None


Expand Down
Loading
Loading