diff --git a/alembic/versions/2026_04_18_1800-6998ce81619a_add_image_url_to_theme.py b/alembic/versions/2026_04_18_1800-6998ce81619a_add_image_url_to_theme.py new file mode 100644 index 0000000..3af215a --- /dev/null +++ b/alembic/versions/2026_04_18_1800-6998ce81619a_add_image_url_to_theme.py @@ -0,0 +1,32 @@ +"""add image_url to theme + +Revision ID: 6998ce81619a +Revises: ded91ca4b249 +Create Date: 2026-04-18 18:00:33.038004 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '6998ce81619a' +down_revision: Union[str, Sequence[str], None] = 'ded91ca4b249' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'theme', + sa.Column('image_url', sqlmodel.sql.sqltypes.AutoString(length=2048), nullable=True), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('theme', 'image_url') diff --git a/app/api.py b/app/api.py index 7fb8355..1b6e679 100644 --- a/app/api.py +++ b/app/api.py @@ -4,19 +4,26 @@ import uuid from contextlib import asynccontextmanager from pathlib import Path -from typing import Optional +from typing import List, Optional -import boto3 from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, Query, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from sqlalchemy.exc import DataError, IntegrityError, OperationalError from starlette.requests import Request -from sqlmodel import Session, col, func, or_, select +from sqlmodel import Session, SQLModel, col, func, or_, select from app.auth import create_access_token, require_admin, verify_admin_password from app.db import get_session +from app.images import ( + ImageCommitError, + ImageGenerationError, + build_theme_prompt, + commit_image_to_r2, + generate_candidate_images, + tags_for_prompt, +) from app.models import ( Breadcrumb, BreadcrumbBase, @@ -32,6 +39,7 @@ ThemeUpdate, Visibility, ) +from app.storage import R2ConfigError, put_object @asynccontextmanager @@ -243,6 +251,80 @@ def delete_theme( session.flush() +# ---------- theme cover image endpoints ---------- + + +class GenerateImageResponse(SQLModel): + prompt: str + candidates: List[str] + + +class CommitImageRequest(SQLModel): + source_url: str + + +@router.post("/themes/{theme_id}/generate-image", response_model=GenerateImageResponse) +def generate_theme_image( + theme_id: int, + session: Session = Depends(get_session), + _admin: None = Depends(require_admin), +): + theme = session.get(Theme, theme_id) + if not theme: + raise HTTPException(status_code=404, detail="Theme not found") + + prompt = build_theme_prompt(theme, tags_for_prompt(theme)) + try: + candidates = generate_candidate_images(prompt) + except ImageGenerationError as e: + logger.warning("Image generation failed for theme %d: %s", theme_id, e) + raise HTTPException(status_code=503, detail=str(e)) + + return GenerateImageResponse(prompt=prompt, candidates=candidates) + + +@router.post("/themes/{theme_id}/image", response_model=ThemePublic) +def commit_theme_image( + theme_id: int, + body: CommitImageRequest, + session: Session = Depends(get_session), + _admin: None = Depends(require_admin), +): + theme = session.get(Theme, theme_id) + if not theme: + raise HTTPException(status_code=404, detail="Theme not found") + + try: + theme.image_url = commit_image_to_r2(body.source_url) + except ImageCommitError as e: + logger.warning("Image commit failed for theme %d: %s", theme_id, e) + raise HTTPException(status_code=400, detail=str(e)) + except R2ConfigError as e: + logger.error("R2 misconfigured during commit for theme %d: %s", theme_id, e) + raise HTTPException(status_code=500, detail="Storage is misconfigured") + + session.add(theme) + session.flush() + session.refresh(theme) + return theme + + +@router.delete("/themes/{theme_id}/image", response_model=ThemePublic) +def clear_theme_image( + theme_id: int, + session: Session = Depends(get_session), + _admin: None = Depends(require_admin), +): + theme = session.get(Theme, theme_id) + if not theme: + raise HTTPException(status_code=404, detail="Theme not found") + theme.image_url = None + session.add(theme) + session.flush() + session.refresh(theme) + return theme + + # ---------- breadcrumb endpoints ---------- @@ -410,22 +492,12 @@ def upload_image( ext = os.path.splitext(file.filename or "")[1] or ".png" key = f"{uuid.uuid4().hex[:12]}{ext}" - s3 = boto3.client( - "s3", - endpoint_url=f"https://{os.getenv('R2_ACCOUNT_ID')}.r2.cloudflarestorage.com", - aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"), - aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"), - region_name="auto", - ) - s3.put_object( - Bucket=os.getenv("R2_BUCKET_NAME"), - Key=key, - Body=contents, - ContentType=file.content_type, - ) - - public_url = f"{os.getenv('R2_PUBLIC_URL', '').rstrip('/')}/{key}" - return {"url": public_url} + try: + url = put_object(key, contents, file.content_type) + except R2ConfigError as e: + logger.error("R2 misconfigured during upload: %s", e) + raise HTTPException(status_code=500, detail="Storage is misconfigured") + return {"url": url} # ---------- app assembly ---------- diff --git a/app/images.py b/app/images.py new file mode 100644 index 0000000..5ac0698 --- /dev/null +++ b/app/images.py @@ -0,0 +1,167 @@ +"""Theme cover image generation via Replicate (Flux Schnell) + R2 storage.""" + +import logging +import os +import uuid +from typing import List +from urllib.parse import urlparse + +import httpx +import replicate +import replicate.exceptions +from dotenv import load_dotenv + +from app.models import Theme +from app.storage import put_object + +load_dotenv() + +logger = logging.getLogger(__name__) + + +FLUX_SCHNELL_MODEL = "black-forest-labs/flux-schnell" +DEFAULT_ASPECT_RATIO = "1:1" +DEFAULT_NUM_OUTPUTS = 4 + +# Replicate delivery hosts — allowlist for SSRF defense on commit. +ALLOWED_SOURCE_HOSTS = frozenset({"replicate.delivery", "pbxt.replicate.delivery"}) + +# Content-type → extension map. Unknown types are rejected rather than defaulted, +# so we don't silently store non-image bytes as .webp. +CONTENT_TYPE_EXTENSIONS = { + "image/webp": ".webp", + "image/png": ".png", + "image/jpeg": ".jpg", +} + +# Magic bytes for the formats we accept. +IMAGE_MAGIC_BYTES = ( + b"RIFF", # webp (followed by size + "WEBP" at offset 8) + b"\x89PNG", # png + b"\xff\xd8\xff", # jpeg +) + +STYLE_SUFFIX = ( + "Rendered as a flat oil painting in a limited three-color palette, " + "figurative and confident, contemporary museum-quality painting. " + "When human figures appear, depict a diversity of people — " + "including people of color, varied ages, and varied body types. " + "Tone balanced between gravity and play." +) + + +class ImageGenerationError(RuntimeError): + """Upstream Replicate error — transient, tell the writer to retry.""" + + +class ImageCommitError(RuntimeError): + """Error downloading or validating a chosen candidate image.""" + + +def build_theme_prompt(theme: Theme, tag_names: List[str]) -> str: + """Translate a theme + tags into a natural-language Flux prompt. + + Scene framing + theme snippet + tag mood + fixed style suffix. + """ + body = (theme.body_md or "").strip() + first_sentence = body.split(".")[0].strip() + snippet = first_sentence if 0 < len(first_sentence) <= 200 else body[:200].strip() + + tags_phrase = "" + if tag_names: + readable = ", ".join(name.replace("-", " ") for name in tag_names) + tags_phrase = f" Mood draws from: {readable}." + + return ( + f'A single figurative scene that evokes the theme: "{snippet}".' + f"{tags_phrase} {STYLE_SUFFIX}" + ) + + +def generate_candidate_images( + prompt: str, + num_outputs: int = DEFAULT_NUM_OUTPUTS, + aspect_ratio: str = DEFAULT_ASPECT_RATIO, +) -> List[str]: + """Call Flux Schnell on Replicate, return a list of temporary image URLs.""" + if not os.getenv("REPLICATE_API_TOKEN"): + raise ImageGenerationError("REPLICATE_API_TOKEN is not set") + + try: + output = replicate.run( + FLUX_SCHNELL_MODEL, + input={ + "prompt": prompt, + "num_outputs": num_outputs, + "aspect_ratio": aspect_ratio, + "output_format": "webp", + "output_quality": 90, + }, + ) + except replicate.exceptions.ReplicateError as e: + logger.error("Replicate API error: %s (prompt_len=%d)", e, len(prompt)) + raise ImageGenerationError(f"Replicate error: {e}") from e + except httpx.HTTPError as e: + logger.error("Replicate network error: %s (prompt_len=%d)", e, len(prompt)) + raise ImageGenerationError(f"Network error contacting Replicate: {e}") from e + + urls = [str(item.url) if hasattr(item, "url") else str(item) for item in output] + if not urls: + logger.error("Replicate returned empty output (prompt_len=%d)", len(prompt)) + raise ImageGenerationError("Replicate returned no candidates") + return urls + + +def _is_allowed_source(source_url: str) -> bool: + """Validate source_url is HTTPS and hosted on an allowlisted Replicate domain.""" + parsed = urlparse(source_url) + if parsed.scheme != "https": + return False + host = parsed.hostname or "" + return host in ALLOWED_SOURCE_HOSTS or any( + host.endswith(f".{h}") for h in ALLOWED_SOURCE_HOSTS + ) + + +def _looks_like_image(data: bytes) -> bool: + return any(data.startswith(magic) for magic in IMAGE_MAGIC_BYTES) + + +def commit_image_to_r2(source_url: str) -> str: + """Download an allowlisted Replicate URL and re-upload to R2. + + Rejects non-HTTPS URLs, hosts outside the Replicate allowlist, redirects, + unknown content types, and payloads that don't start with a known image + magic-byte sequence. Returns the permanent public R2 URL. + """ + if not _is_allowed_source(source_url): + raise ImageCommitError( + f"source_url host not in allowlist: {urlparse(source_url).hostname}" + ) + + try: + response = httpx.get(source_url, follow_redirects=False, timeout=30) + response.raise_for_status() + except httpx.HTTPError as e: + logger.error("Failed to download candidate image from %s: %s", source_url, e) + raise ImageCommitError(f"Could not download candidate: {e}") from e + + data = response.content + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + ext = CONTENT_TYPE_EXTENSIONS.get(content_type) + if ext is None: + logger.error("Rejecting unknown content-type from %s: %s", source_url, content_type) + raise ImageCommitError(f"Unsupported content-type: {content_type!r}") + + if not _looks_like_image(data): + logger.error("Payload from %s did not match image magic bytes", source_url) + raise ImageCommitError("Downloaded payload is not a recognized image") + + key = f"theme-{uuid.uuid4().hex[:12]}{ext}" + return put_object(key, data, content_type) + + +def tags_for_prompt(theme: Theme) -> List[str]: + """Extract tag names safely; theme.tags may be empty or not loaded.""" + return [t.name for t in (theme.tags or [])] diff --git a/app/models.py b/app/models.py index e50a797..30eb0cb 100644 --- a/app/models.py +++ b/app/models.py @@ -33,6 +33,11 @@ class ThemeBase(SQLModel, table=False): visibility: Visibility = Field( default=Visibility.draft, description="The theme's status (draft or published)" ) + image_url: Optional[str] = Field( + default=None, + max_length=2048, + description="Public R2 URL for the theme's generated cover image", + ) created_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), description="When this theme was created", @@ -76,6 +81,7 @@ class ThemeCreate(ThemeBase, table=False): class ThemeUpdate(SQLModel, table=False): body_md: Optional[str] = None visibility: Optional[Visibility] = None + image_url: Optional[str] = None tags: Optional[List["TagCreate"]] = None diff --git a/app/storage.py b/app/storage.py new file mode 100644 index 0000000..ce30441 --- /dev/null +++ b/app/storage.py @@ -0,0 +1,59 @@ +"""Cloudflare R2 (S3-compatible) storage helpers for uploads and generated images.""" + +import logging +import os +from typing import Optional + +import boto3 +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + +REQUIRED_R2_VARS = ( + "R2_ACCOUNT_ID", + "R2_ACCESS_KEY_ID", + "R2_SECRET_ACCESS_KEY", + "R2_BUCKET_NAME", + "R2_PUBLIC_URL", +) + + +class R2ConfigError(RuntimeError): + """Raised when required R2 environment variables are missing.""" + + +def assert_r2_env() -> None: + missing = [v for v in REQUIRED_R2_VARS if not os.getenv(v)] + if missing: + raise R2ConfigError(f"Missing R2 env vars: {', '.join(missing)}") + + +def _client(): + return boto3.client( + "s3", + endpoint_url=f"https://{os.getenv('R2_ACCOUNT_ID')}.r2.cloudflarestorage.com", + aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"), + region_name="auto", + ) + + +def public_url(key: str) -> str: + base: Optional[str] = os.getenv("R2_PUBLIC_URL") + if not base: + raise R2ConfigError("R2_PUBLIC_URL is not set") + return f"{base.rstrip('/')}/{key}" + + +def put_object(key: str, body: bytes, content_type: str) -> str: + """Upload bytes to R2 under `key` and return the permanent public URL.""" + assert_r2_env() + _client().put_object( + Bucket=os.getenv("R2_BUCKET_NAME"), + Key=key, + Body=body, + ContentType=content_type, + ) + return public_url(key) diff --git a/frontend/src/components/theme-section.tsx b/frontend/src/components/theme-section.tsx index aea0770..98c6a51 100644 --- a/frontend/src/components/theme-section.tsx +++ b/frontend/src/components/theme-section.tsx @@ -11,9 +11,10 @@ import { cn } from "@/lib/utils" interface ThemeSectionProps { theme: ThemePublic + variant?: "feed" | "permalink" } -export function ThemeSection({ theme }: ThemeSectionProps) { +export function ThemeSection({ theme, variant = "feed" }: ThemeSectionProps) { const { data: breadcrumbs, isLoading, @@ -30,7 +31,21 @@ export function ThemeSection({ theme }: ThemeSectionProps) { return (
-
+ {theme.image_url && variant === "permalink" && ( + + )} +
+ {theme.image_url && variant === "feed" && ( + + )}
{theme.body_md}
diff --git a/frontend/src/components/writer/theme-image-picker.tsx b/frontend/src/components/writer/theme-image-picker.tsx new file mode 100644 index 0000000..7c28c9a --- /dev/null +++ b/frontend/src/components/writer/theme-image-picker.tsx @@ -0,0 +1,124 @@ +import { useState } from "react" +import { useMutation, useQueryClient } from "@tanstack/react-query" +import { Sparkles, Trash2 } from "lucide-react" +import { Button } from "@/components/ui/button" +import { Label } from "@/components/ui/label" +import { + clearThemeImage, + commitThemeImage, + generateThemeImage, +} from "@/lib/api" +import type { ThemePublic } from "@/lib/types" + +interface ThemeImagePickerProps { + theme: ThemePublic +} + +export function ThemeImagePicker({ theme }: ThemeImagePickerProps) { + const queryClient = useQueryClient() + const [candidates, setCandidates] = useState([]) + const [prompt, setPrompt] = useState(null) + + const generate = useMutation({ + mutationFn: () => generateThemeImage(theme.id), + onSuccess: (res) => { + setCandidates(res.candidates) + setPrompt(res.prompt) + }, + }) + + const commit = useMutation({ + mutationFn: (sourceUrl: string) => commitThemeImage(theme.id, sourceUrl), + onSuccess: (updated) => { + queryClient.setQueryData(["themes", theme.id], updated) + queryClient.invalidateQueries({ queryKey: ["themes"] }) + setCandidates([]) + setPrompt(null) + }, + }) + + const clearImage = useMutation({ + mutationFn: () => clearThemeImage(theme.id), + onSuccess: (updated) => { + queryClient.setQueryData(["themes", theme.id], updated) + queryClient.invalidateQueries({ queryKey: ["themes"] }) + }, + }) + + return ( +
+
+ +
+ + {theme.image_url && ( + + )} +
+
+ + {theme.image_url && candidates.length === 0 && ( + + )} + + {generate.error && ( +

{generate.error.message}

+ )} + {commit.error && ( +

{commit.error.message}

+ )} + {clearImage.error && ( +

{clearImage.error.message}

+ )} + + {candidates.length > 0 && ( +
+

+ Click one to save. {prompt && Prompt: {prompt}} +

+
+ {candidates.map((url) => ( + + ))} +
+
+ )} +
+ ) +} diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 6c10921..894a7a8 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -4,6 +4,7 @@ import type { BreadcrumbPublic, DigestPublic, DigestType, + GenerateImageResponse, TagWithCount, ThemeCreateInput, ThemePublic, @@ -172,6 +173,31 @@ export function deleteTheme(themeId: number): Promise { }) } +export function generateThemeImage(themeId: number): Promise { + return apiMutate(`/api/themes/${themeId}/generate-image`, { + method: "POST", + label: "generate theme image", + }) +} + +export function commitThemeImage( + themeId: number, + sourceUrl: string, +): Promise { + return apiMutate(`/api/themes/${themeId}/image`, { + method: "POST", + body: { source_url: sourceUrl }, + label: "commit theme image", + }) +} + +export function clearThemeImage(themeId: number): Promise { + return apiMutate(`/api/themes/${themeId}/image`, { + method: "DELETE", + label: "clear theme image", + }) +} + // --------------------------------------------------------------------------- // Breadcrumb mutations // --------------------------------------------------------------------------- diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index b513e0b..9e9e0b2 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -17,6 +17,7 @@ export interface ThemePublic { id: number body_md: string visibility: Visibility + image_url: string | null created_at: string updated_at: string | null tags: TagPublic[] @@ -47,13 +48,19 @@ export interface ThemeCreateInput { tags?: { name: string }[] } -/** Input for PUT /themes/{id} — all fields optional (partial update) */ +/** Input for PUT /themes/{id} — all fields optional (partial update). image_url is set/cleared via dedicated endpoints. */ export interface ThemeUpdateInput { body_md?: string visibility?: Visibility tags?: { name: string }[] } +/** Response from POST /themes/{id}/generate-image */ +export interface GenerateImageResponse { + prompt: string + candidates: string[] +} + /** Input for POST breadcrumbs */ export interface BreadcrumbInput { body_md: string diff --git a/frontend/src/routes/themes.$themeId.tsx b/frontend/src/routes/themes.$themeId.tsx index fb35965..bf95502 100644 --- a/frontend/src/routes/themes.$themeId.tsx +++ b/frontend/src/routes/themes.$themeId.tsx @@ -60,7 +60,7 @@ function ThemePermalink() { return (
- +
) } diff --git a/frontend/src/routes/writer/themes.$themeId.tsx b/frontend/src/routes/writer/themes.$themeId.tsx index 40fd0b0..29730da 100644 --- a/frontend/src/routes/writer/themes.$themeId.tsx +++ b/frontend/src/routes/writer/themes.$themeId.tsx @@ -4,6 +4,7 @@ import { useQuery } from "@tanstack/react-query" import { fetchTheme, fetchBreadcrumbs } from "@/lib/api" import { buildTree } from "@/lib/tree" import { ThemeHeaderEditor } from "@/components/writer/theme-header-editor" +import { ThemeImagePicker } from "@/components/writer/theme-image-picker" import { BreadcrumbItem } from "@/components/writer/breadcrumb-item" import { AddBreadcrumbForm } from "@/components/writer/add-breadcrumb-form" import { Separator } from "@/components/ui/separator" @@ -81,6 +82,8 @@ function ThemeEditor() {
+ +
diff --git a/pyproject.toml b/pyproject.toml index 2f0f1c3..eb1cd36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "tenacity>=9.1.0", "apscheduler>=3.11.2", "boto3>=1.35.0", + "replicate>=1.0.7", + "httpx>=0.28.1", ] [project.scripts] diff --git a/tests/test_api_themes.py b/tests/test_api_themes.py index da43a22..6be6545 100644 --- a/tests/test_api_themes.py +++ b/tests/test_api_themes.py @@ -262,3 +262,133 @@ def test_delete_theme_leaves_tags_intact(client, session): def test_delete_theme_not_found(client): response = client.delete("/api/themes/999") assert response.status_code == 404 + + +# ---------- theme cover image endpoints ---------- + + +def test_generate_theme_image_not_found(client): + response = client.post("/api/themes/999/generate-image") + assert response.status_code == 404 + + +def test_commit_theme_image_not_found(client): + response = client.post( + "/api/themes/999/image", + json={"source_url": "https://replicate.delivery/x/out.webp"}, + ) + assert response.status_code == 404 + + +def test_clear_theme_image_not_found(client): + response = client.delete("/api/themes/999/image") + assert response.status_code == 404 + + +def test_update_theme_cannot_set_image_url_via_put(client): + """image_url is not in THEME_UPDATABLE_FIELDS — must use the dedicated endpoints.""" + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + + # Pydantic accepts the field on input but the whitelist rejects it. + response = client.put( + f"/api/themes/{theme_id}", json={"image_url": "https://evil.example/x.png"} + ) + assert response.status_code == 400 + assert "image_url" in response.json()["detail"] + + +def test_generate_theme_image_missing_token(client, monkeypatch): + monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False) + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + + response = client.post(f"/api/themes/{theme_id}/generate-image") + assert response.status_code == 503 + assert "REPLICATE_API_TOKEN" in response.json()["detail"] + + +def test_generate_theme_image_happy_path(client, monkeypatch): + monkeypatch.setenv("REPLICATE_API_TOKEN", "fake-token") + fake_urls = [f"https://replicate.delivery/x/out-{i}.webp" for i in range(4)] + monkeypatch.setattr("app.images.replicate.run", lambda *a, **k: fake_urls) + + r = client.post( + "/api/themes", + json={"body_md": "On the texture of an afternoon", "tags": [{"name": "slowness"}]}, + ) + theme_id = r.json()["id"] + + response = client.post(f"/api/themes/{theme_id}/generate-image") + assert response.status_code == 200 + body = response.json() + assert body["candidates"] == fake_urls + assert "On the texture of an afternoon" in body["prompt"] + assert "Mood draws from: slowness" in body["prompt"] + + +def test_commit_theme_image_rejects_non_replicate_host(client): + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + + response = client.post( + f"/api/themes/{theme_id}/image", + json={"source_url": "https://evil.example/out.webp"}, + ) + assert response.status_code == 400 + assert "allowlist" in response.json()["detail"] + + +def test_commit_theme_image_rejects_non_https(client): + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + + response = client.post( + f"/api/themes/{theme_id}/image", + json={"source_url": "http://replicate.delivery/x/out.webp"}, + ) + assert response.status_code == 400 + + +def test_commit_theme_image_happy_path(client, monkeypatch, session): + """Mock commit_image_to_r2 at the api seam to avoid network + R2.""" + monkeypatch.setattr( + "app.api.commit_image_to_r2", + lambda url: "https://cdn.example/permanent.webp", + ) + + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + + response = client.post( + f"/api/themes/{theme_id}/image", + json={"source_url": "https://replicate.delivery/x/out.webp"}, + ) + assert response.status_code == 200 + assert response.json()["image_url"] == "https://cdn.example/permanent.webp" + + # Persisted on the model. + session.expire_all() + refreshed = session.get(Theme, theme_id) + assert refreshed.image_url == "https://cdn.example/permanent.webp" + + +def test_clear_theme_image_unsets_field(client, monkeypatch, session): + monkeypatch.setattr( + "app.api.commit_image_to_r2", + lambda url: "https://cdn.example/permanent.webp", + ) + r = client.post("/api/themes", json={"body_md": "Theme"}) + theme_id = r.json()["id"] + client.post( + f"/api/themes/{theme_id}/image", + json={"source_url": "https://replicate.delivery/x/out.webp"}, + ) + + response = client.delete(f"/api/themes/{theme_id}/image") + assert response.status_code == 200 + assert response.json()["image_url"] is None + + session.expire_all() + refreshed = session.get(Theme, theme_id) + assert refreshed.image_url is None diff --git a/tests/test_images.py b/tests/test_images.py new file mode 100644 index 0000000..4b4df80 --- /dev/null +++ b/tests/test_images.py @@ -0,0 +1,100 @@ +"""Unit tests for app.images pure-function logic.""" + +import pytest + +from app.images import ( + STYLE_SUFFIX, + ImageCommitError, + _is_allowed_source, + _looks_like_image, + build_theme_prompt, + tags_for_prompt, +) +from app.models import Tag, Theme + + +def test_build_theme_prompt_includes_snippet_and_suffix(): + theme = Theme(body_md="A short, sharp thought") + prompt = build_theme_prompt(theme, []) + assert "A short, sharp thought" in prompt + assert STYLE_SUFFIX in prompt + + +def test_build_theme_prompt_uses_first_sentence(): + theme = Theme(body_md="First sentence here. Second sentence we drop.") + prompt = build_theme_prompt(theme, []) + assert "First sentence here" in prompt + assert "Second sentence we drop" not in prompt + + +def test_build_theme_prompt_truncates_overlong_first_sentence(): + long = "x" * 500 + theme = Theme(body_md=long) + prompt = build_theme_prompt(theme, []) + snippet = long[:200] + assert snippet in prompt + assert "x" * 201 not in prompt + + +def test_build_theme_prompt_renders_tags_as_mood(): + theme = Theme(body_md="Body") + prompt = build_theme_prompt(theme, ["deep-work", "slow-mornings"]) + assert "Mood draws from: deep work, slow mornings" in prompt + + +def test_build_theme_prompt_no_tags_no_mood_phrase(): + theme = Theme(body_md="Body") + prompt = build_theme_prompt(theme, []) + assert "Mood draws from" not in prompt + + +def test_build_theme_prompt_empty_body(): + theme = Theme(body_md="") + prompt = build_theme_prompt(theme, []) + assert STYLE_SUFFIX in prompt + + +def test_tags_for_prompt_empty_when_none(): + theme = Theme(body_md="x") + theme.tags = [] + assert tags_for_prompt(theme) == [] + + +def test_tags_for_prompt_extracts_names(): + theme = Theme(body_md="x") + theme.tags = [Tag(name="alpha"), Tag(name="beta")] + assert tags_for_prompt(theme) == ["alpha", "beta"] + + +@pytest.mark.parametrize( + "url,expected", + [ + ("https://replicate.delivery/xezq/abc/out-0.webp", True), + ("https://pbxt.replicate.delivery/foo/bar.webp", True), + ("https://something.replicate.delivery/x.webp", True), + ("http://replicate.delivery/x.webp", False), # not https + ("https://evil.example.com/x.webp", False), + ("https://localhost/x.webp", False), + ("https://replicate.delivery.evil.com/x.webp", False), + ("file:///etc/passwd", False), + ], +) +def test_is_allowed_source_validates_host_and_scheme(url, expected): + assert _is_allowed_source(url) is expected + + +def test_looks_like_image_accepts_known_magic_bytes(): + assert _looks_like_image(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8) + assert _looks_like_image(b"\xff\xd8\xff\xe0" + b"\x00" * 8) + assert _looks_like_image(b"RIFF\x00\x00\x00\x00WEBP") + + +def test_looks_like_image_rejects_html_and_text(): + assert not _looks_like_image(b"Not an image") + assert not _looks_like_image(b'{"error": "expired"}') + assert not _looks_like_image(b"") + + +def test_image_commit_error_is_runtime_error(): + """Sanity: ImageCommitError is catchable as RuntimeError for backwards compat.""" + assert issubclass(ImageCommitError, RuntimeError) diff --git a/uv.lock b/uv.lock index 068a2e1..cdd433d 100644 --- a/uv.lock +++ b/uv.lock @@ -118,10 +118,12 @@ dependencies = [ { name = "dotenv" }, { name = "email-validator" }, { name = "fastapi", extra = ["standard"] }, + { name = "httpx" }, { name = "mistune" }, { name = "psycopg", extra = ["binary"] }, { name = "pydantic" }, { name = "pyjwt" }, + { name = "replicate" }, { name = "resend" }, { name = "sqlalchemy" }, { name = "sqlmodel" }, @@ -143,10 +145,12 @@ requires-dist = [ { name = "dotenv", specifier = ">=0.9.9" }, { name = "email-validator", specifier = ">=2.1.0" }, { name = "fastapi", extras = ["standard"], specifier = "~=0.120.1" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "mistune", specifier = ">=3.1.0" }, { name = "psycopg", extras = ["binary"], specifier = ">=3.2.12" }, { name = "pydantic", specifier = ">=2.12.3" }, { name = "pyjwt", specifier = ">=2.8.0" }, + { name = "replicate", specifier = ">=1.0.7" }, { name = "resend", specifier = ">=2.7.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "sqlmodel", specifier = ">=0.0.27" }, @@ -856,6 +860,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "replicate" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/fd/caf6c59a6b8007366bd52ab5a320bf8d828f3860a60039309cfc0e375ec9/replicate-1.0.7.tar.gz", hash = "sha256:d88cb2c37ba39fb370c87fc3291601c67aae64bb918a20a85b5ce399c23ee84c", size = 62226, upload-time = "2025-05-27T11:29:08.111Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/5a/b3aa02a11a33de08e7771579154af3193decfb9d923b30b14c17b4e8bbce/replicate-1.0.7-py3-none-any.whl", hash = "sha256:667c50a9eb83be17de6278ff89483102b3b50f49a2c7fbcaa2e2b14df13816f9", size = 48626, upload-time = "2025-05-27T11:29:06.801Z" }, +] + [[package]] name = "requests" version = "2.32.5"