Skip to content

Commit 3559aa5

Browse files
authored
refactor: use Depends for settings and session_local (#59)
1 parent 557de5a commit 3559aa5

20 files changed

+205
-120
lines changed

app/shared/celery.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
from celery import Celery
22

3-
from app.shared.settings import settings
43

5-
6-
def get_celery_binding() -> Celery:
7-
celery = Celery(
8-
broker_url=settings.BROKER_URL,
4+
def get_celery_binding(broker_url: str) -> Celery:
5+
return Celery(
6+
broker_url=broker_url,
97
broker_connection_retry=False,
108
broker_connection_retry_on_startup=False,
119
)
12-
13-
return celery

app/shared/db/alembic/env.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from sqlalchemy import engine_from_config, pool
55

66
from app.shared.db.models import Base
7-
from app.shared.settings import settings
7+
from app.shared.settings import Settings
8+
9+
settings = Settings() # type: ignore
810

911
# this is the Alembic Config object, which provides
1012
# access to the values within the .ini file in use.

app/shared/db/base.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
1-
from typing import Any, Generator
1+
from typing import Any
22

3-
from sqlalchemy import create_engine, event
4-
from sqlalchemy.orm import Session, sessionmaker
3+
from sqlalchemy import Engine, create_engine, event
4+
from sqlalchemy.orm import sessionmaker
55

6-
from app.shared.settings import settings
76

8-
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
7+
def make_engine(database_url: str):
8+
engine = create_engine(database_url, connect_args={"check_same_thread": False})
99

10+
@event.listens_for(engine, "connect")
11+
def set_sqlite_pragma(conn: Any, _: Any) -> None:
12+
cursor = conn.cursor()
13+
cursor.execute("PRAGMA journal_mode=WAL")
14+
cursor.close()
1015

11-
@event.listens_for(engine, "connect")
12-
def set_sqlite_pragma(conn: Any, _: Any) -> None:
13-
cursor = conn.cursor()
14-
cursor.execute("PRAGMA journal_mode=WAL")
15-
cursor.close()
16+
return engine
1617

1718

18-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19-
20-
21-
def get_session() -> Generator[Session, None, None]:
22-
session: Session = SessionLocal()
23-
try:
24-
yield session
25-
finally:
26-
session.close()
19+
def make_session_local(engine: Engine):
20+
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
21+
return session_local

app/shared/logger.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import logging
2+
3+
logging.basicConfig()
4+
5+
logger = logging.getLogger(__name__)
6+
7+
logger.setLevel(logging.INFO)

app/shared/settings.py

-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import sys
2-
31
from pydantic_settings import BaseSettings
42

53

@@ -13,9 +11,3 @@ class Settings(BaseSettings):
1311
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
1412

1513
ENABLE_SHARING: bool = False
16-
17-
18-
if "pytest" in sys.modules:
19-
settings = Settings(_env_file=".env.test") # type: ignore
20-
else:
21-
settings = Settings() # type: ignore

app/tests/conftest.py

+28-22
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,57 @@
33
from sqlalchemy_utils import create_database, database_exists, drop_database
44

55
import app.shared.db.models as models
6-
from app.shared.db.base import SessionLocal, engine
7-
from app.shared.settings import settings
6+
from app.shared.db.base import make_engine, make_session_local
7+
from app.shared.settings import Settings
8+
from app.web.injections.db import get_session
9+
from app.web.injections.settings import get_settings
810
from app.web.main import app_factory
911

1012

11-
def pytest_configure() -> None:
12-
if not database_exists(engine.url):
13-
create_database(engine.url)
14-
15-
16-
def pytest_unconfigure() -> None:
17-
if database_exists(engine.url):
18-
drop_database(engine.url)
13+
@pytest.fixture()
14+
def settings():
15+
return Settings(_env_file=".env.test") # type: ignore
1916

2017

2118
@pytest.fixture()
22-
def auth_headers() -> dict[str, str]:
19+
def auth_headers(settings) -> dict[str, str]:
2320
return {"Authorization": f"Bearer {settings.API_SECRET}"}
2421

2522

2623
@pytest.fixture()
27-
def test_db():
24+
def test_db(settings):
25+
engine = make_engine(settings.DATABASE_URI)
26+
27+
if not database_exists(engine.url):
28+
create_database(engine.url)
29+
2830
models.Base.metadata.create_all(engine)
31+
2932
connection = engine.connect()
3033
yield connection
3134
connection.close()
35+
3236
models.Base.metadata.drop_all(bind=engine)
37+
drop_database(engine.url)
3338

3439

3540
@pytest.fixture()
3641
def db_session(test_db):
37-
with SessionLocal(bind=test_db) as session:
42+
session_local = make_session_local(test_db)
43+
with session_local() as session:
3844
yield session
3945

4046

4147
@pytest.fixture()
42-
def client(db_session):
43-
app = app_factory(lambda: db_session)
48+
def app(db_session, settings):
49+
app = app_factory()
50+
app.dependency_overrides[get_settings] = lambda: settings
51+
app.dependency_overrides[get_session] = lambda: db_session
52+
return app
53+
54+
55+
@pytest.fixture()
56+
def client(app):
4457
client = TestClient(app)
4558
return client
4659

@@ -66,10 +79,3 @@ def mock_artifact(db_session, mock_job):
6679
db_session.add(artifact)
6780
db_session.commit()
6881
return artifact
69-
70-
71-
@pytest.fixture()
72-
def sharing_enabled():
73-
settings.ENABLE_SHARING = True
74-
yield
75-
settings.ENABLE_SHARING = False

app/tests/test_api.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from fastapi.testclient import TestClient
2-
31
import app.shared.db.models as models
4-
from app.web.main import app_factory
2+
from app.shared.settings import Settings
3+
from app.web.injections.settings import get_settings
54

65

76
# POST /api/v1/jobs
@@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
6968
assert res.status_code == 401
7069

7170

72-
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
73-
# HACK: delay construction until settings are patched.
74-
client = TestClient(app_factory(lambda: db_session))
71+
def test_get_job_sharing_enabled(client, app, mock_job):
72+
app.dependency_overrides[get_settings] = lambda: Settings(
73+
_env_file=".env.test", ENABLE_SHARING=True # type: ignore
74+
)
7575

7676
res = client.get(
7777
f"/api/v1/jobs/{mock_job.id}",

app/web/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from app.shared.db.base import get_session
21
from app.web.main import app_factory
32

4-
app = app_factory(get_session)
3+
app = app_factory

app/web/injections/__init__.py

Whitespace-only changes.

app/web/injections/db.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from functools import lru_cache
2+
from typing import Generator
3+
4+
from fastapi import Depends
5+
from sqlalchemy.orm import Session
6+
7+
from app.shared.db.base import make_engine, make_session_local
8+
from app.shared.settings import Settings
9+
from app.web.injections.settings import get_settings
10+
11+
12+
@lru_cache
13+
def session_local(database_url: str):
14+
engine = make_engine(database_url)
15+
return make_session_local(engine)
16+
17+
18+
def get_session_local(settings: Settings = Depends(get_settings)):
19+
return session_local(settings.DATABASE_URI)
20+
21+
22+
def get_session(
23+
session_local=Depends(get_session_local),
24+
) -> Generator[Session, None, None]:
25+
with session_local() as session:
26+
yield session

app/web/injections/security.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from hmac import compare_digest
2+
from typing import Annotated
3+
4+
from fastapi import Depends, HTTPException
5+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
6+
7+
from app.shared.settings import Settings
8+
from app.web.injections.settings import get_settings
9+
10+
11+
def api_key_auth(
12+
credentials: Annotated[
13+
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
14+
],
15+
settings: Annotated[Settings, Depends(get_settings)],
16+
):
17+
validate_credentials(credentials, settings.API_SECRET)
18+
19+
20+
def sharing_auth(
21+
credentials: Annotated[
22+
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
23+
],
24+
settings: Annotated[Settings, Depends(get_settings)],
25+
):
26+
if settings.ENABLE_SHARING:
27+
pass
28+
else:
29+
validate_credentials(credentials, settings.API_SECRET)
30+
31+
32+
def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str):
33+
# use compare_digest to counter timing attacks.
34+
if (
35+
not credentials
36+
or not secret
37+
or not compare_digest(secret, credentials.credentials)
38+
):
39+
raise HTTPException(status_code=401)

app/web/injections/settings.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from functools import lru_cache
2+
3+
from app.shared.settings import Settings
4+
5+
6+
@lru_cache
7+
def get_settings():
8+
return Settings() # type: ignore

app/web/injections/task_queue.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from functools import lru_cache
2+
3+
from fastapi import Depends
4+
5+
from app.shared.settings import Settings
6+
from app.web.injections.settings import get_settings
7+
from app.web.task_queue import TaskQueue
8+
9+
10+
@lru_cache
11+
def task_queue(broker_url: str):
12+
return TaskQueue(broker_url)
13+
14+
15+
def get_task_queue(settings: Settings = Depends(get_settings)):
16+
return task_queue(settings.BROKER_URL)

app/web/main.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated, Callable, Generator
1+
from typing import Annotated
22
from uuid import UUID
33

44
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
@@ -7,18 +7,15 @@
77

88
import app.shared.db.models as models
99
import app.web.dtos as dtos
10-
from app.shared.settings import settings
11-
from app.web.security import authenticate_api_key
10+
from app.web.injections.db import get_session
11+
from app.web.injections.security import api_key_auth, sharing_auth
12+
from app.web.injections.task_queue import get_task_queue
1213
from app.web.task_queue import TaskQueue
1314

15+
DatabaseSession = Annotated[Session, Depends(get_session)]
1416

15-
def app_factory(
16-
session_getter: Callable[[], Generator[Session, None, None]]
17-
) -> FastAPI:
18-
DatabaseSession = Annotated[Session, Depends(session_getter)]
19-
20-
task_queue = TaskQueue()
2117

18+
def app_factory():
2219
app = FastAPI(
2320
description=(
2421
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
@@ -28,13 +25,13 @@ def app_factory(
2825

2926
api_router = APIRouter(prefix="/api/v1")
3027

31-
@api_router.get("/", response_model=None, status_code=204)
32-
def api_root() -> None:
28+
@api_router.get("/", status_code=204)
29+
def api_root():
3330
return None
3431

3532
@api_router.get(
3633
"/jobs",
37-
dependencies=[Depends(authenticate_api_key)],
34+
dependencies=[Depends(api_key_auth)],
3835
response_model=list[dtos.Job],
3936
summary="Get metadata for all jobs",
4037
)
@@ -52,7 +49,7 @@ def get_jobs(
5249

5350
@api_router.get(
5451
"/jobs/{id}",
55-
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
52+
dependencies=[Depends(sharing_auth)],
5653
response_model=dtos.Job,
5754
summary="Get metadata for one job",
5855
)
@@ -72,7 +69,7 @@ def get_job(
7269

7370
@api_router.get(
7471
"/jobs/{id}/artifacts",
75-
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
72+
dependencies=[Depends(api_key_auth)],
7673
response_model=list[dtos.Artifact],
7774
summary="Get all artifacts for one job",
7875
)
@@ -93,7 +90,7 @@ def get_artifacts_for_job(
9390

9491
@api_router.delete(
9592
"/jobs/{id}",
96-
dependencies=[Depends(authenticate_api_key)],
93+
dependencies=[Depends(sharing_auth)],
9794
status_code=204,
9895
summary="Delete a job with all artifacts",
9996
)
@@ -130,14 +127,15 @@ class PostJobPayload(BaseModel):
130127

131128
@api_router.post(
132129
"/jobs",
133-
dependencies=[Depends(authenticate_api_key)],
130+
dependencies=[Depends(api_key_auth)],
134131
response_model=dtos.Job,
135132
status_code=201,
136133
summary="Enqueue a new job",
137134
)
138135
def create_job(
139136
payload: PostJobPayload,
140137
session: DatabaseSession,
138+
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
141139
) -> models.Job:
142140
"""
143141
Enqueue a new whisper job for processing.

0 commit comments

Comments
 (0)