Skip to content

Commit 989d8ad

Browse files
committed
feat: Apply Querier pattern in get registry action
1 parent 07203ad commit 989d8ad

File tree

7 files changed

+276
-45
lines changed

7 files changed

+276
-45
lines changed

src/ai/backend/manager/models/container_registry.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uuid
66
from collections.abc import Sequence
77
from dataclasses import dataclass
8-
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Self, cast
8+
from typing import TYPE_CHECKING, Any, Optional, Self, cast
99
from urllib.parse import urlparse
1010

1111
import graphene
@@ -230,33 +230,6 @@ async def get_container_registry_info(
230230

231231
return yarl.URL(url), creds
232232

233-
@classmethod
234-
async def get_known_container_registries(
235-
cls,
236-
session: AsyncSession,
237-
) -> Mapping[str, Mapping[str, yarl.URL]]:
238-
query_stmt = (
239-
sa.select(ContainerRegistryRow)
240-
.options(
241-
load_only(
242-
ContainerRegistryRow.project,
243-
ContainerRegistryRow.registry_name,
244-
ContainerRegistryRow.url,
245-
)
246-
)
247-
.order_by(ContainerRegistryRow.registry_name, ContainerRegistryRow.project)
248-
)
249-
registries = cast(list[ContainerRegistryRow], (await session.scalars(query_stmt)).all())
250-
result: MutableMapping[str, MutableMapping[str, yarl.URL]] = {}
251-
for registry_row in registries:
252-
project = registry_row.project
253-
registry_name = registry_row.registry_name
254-
url = registry_row.url
255-
if project not in result:
256-
result[project] = {}
257-
result[project][registry_name] = yarl.URL(url)
258-
return result
259-
260233
@classmethod
261234
def from_dataclass(cls, data: ContainerRegistryData) -> Self:
262235
instance = cls(

src/ai/backend/manager/repositories/container_registry/db_source/db_source.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from uuid import UUID
66

77
import sqlalchemy as sa
8+
import yarl
89
from sqlalchemy.ext.asyncio import AsyncSession as SASession
910

1011
from ai.backend.common.container_registry import AllowedGroupsModel
@@ -30,6 +31,7 @@
3031
)
3132
from ai.backend.manager.models.image import ImageRow
3233
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
34+
from ai.backend.manager.repositories.base import Querier, execute_querier
3335

3436
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
3537

@@ -198,24 +200,29 @@ async def mark_images_as_deleted(
198200

199201
await session.execute(update_stmt)
200202

201-
async def fetch_known_registries(self) -> list[ContainerRegistryLocationInfo]:
203+
async def fetch_known_registries(
204+
self, querier: Optional[Querier]
205+
) -> list[ContainerRegistryLocationInfo]:
202206
"""Fetch all known container registries from the database."""
203207
async with self._db.begin_readonly_session() as session:
204-
known_registries_map = await ContainerRegistryRow.get_known_container_registries(
205-
session
208+
query_stmt = sa.select(
209+
ContainerRegistryRow.project,
210+
ContainerRegistryRow.registry_name,
211+
ContainerRegistryRow.url,
212+
sa.func.count().over().label("total_count"),
206213
)
207214

215+
result = await execute_querier(session, query_stmt, querier, ContainerRegistryRow)
216+
rows = result.rows
208217
known_registries: list[ContainerRegistryLocationInfo] = []
209-
for project, registries in known_registries_map.items():
210-
for registry_name, url in registries.items():
211-
if project not in known_registries:
212-
known_registries.append(
213-
ContainerRegistryLocationInfo(
214-
registry_name=registry_name,
215-
project=project,
216-
url=url.human_repr(),
217-
)
218-
)
218+
for row in rows:
219+
known_registries.append(
220+
ContainerRegistryLocationInfo(
221+
registry_name=row.registry_name,
222+
project=row.project,
223+
url=yarl.URL(row.url).human_repr(),
224+
)
225+
)
219226

220227
return known_registries
221228

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from ai.backend.manager.models.container_registry import ContainerRegistryRow
2+
from ai.backend.manager.repositories.base import QueryOrder
3+
4+
5+
class ContainerRegistryOrders:
6+
@staticmethod
7+
def project(ascending: bool = True) -> QueryOrder:
8+
if ascending:
9+
return ContainerRegistryRow.project.asc()
10+
else:
11+
return ContainerRegistryRow.project.desc()
12+
13+
@staticmethod
14+
def registry_name(ascending: bool = True) -> QueryOrder:
15+
if ascending:
16+
return ContainerRegistryRow.registry_name.asc()
17+
else:
18+
return ContainerRegistryRow.registry_name.desc()

src/ai/backend/manager/repositories/container_registry/repository.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from ai.backend.manager.models.container_registry import ContainerRegistryRow
1616
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
17+
from ai.backend.manager.repositories.base import Querier
1718
from ai.backend.manager.repositories.container_registry.db_source.db_source import (
1819
ContainerRegistryDBSource,
1920
)
@@ -92,8 +93,10 @@ async def clear_images(
9293
return await self._db_source.fetch_by_registry_and_project(registry_name, project)
9394

9495
@container_registry_repository_resilience.apply()
95-
async def get_known_registries(self) -> list[ContainerRegistryLocationInfo]:
96-
return await self._db_source.fetch_known_registries()
96+
async def get_known_registries(
97+
self, querier: Optional[Querier]
98+
) -> list[ContainerRegistryLocationInfo]:
99+
return await self._db_source.fetch_known_registries(querier)
97100

98101
@container_registry_repository_resilience.apply()
99102
async def get_registry_row_for_scanner(

src/ai/backend/manager/services/container_registry/actions/get_container_registries.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
from ai.backend.manager.data.container_registry.types import (
66
ContainerRegistryLocationInfo,
77
)
8+
from ai.backend.manager.repositories.base import Querier
89
from ai.backend.manager.services.container_registry.actions.base import ContainerRegistryAction
910

1011

1112
@dataclass
1213
class GetContainerRegistriesAction(ContainerRegistryAction):
14+
querier: Optional[Querier] = None
15+
1316
@override
1417
def entity_id(self) -> Optional[str]:
1518
return None

src/ai/backend/manager/services/container_registry/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def load_all_container_registries(
133133
return LoadAllContainerRegistriesActionResult(registries=registries)
134134

135135
async def get_container_registries(
136-
self, _action: GetContainerRegistriesAction
136+
self, action: GetContainerRegistriesAction
137137
) -> GetContainerRegistriesActionResult:
138-
registries = await self._container_registry_repository.get_known_registries()
138+
registries = await self._container_registry_repository.get_known_registries(action.querier)
139139
return GetContainerRegistriesActionResult(registries=registries)

0 commit comments

Comments
 (0)