Skip to content

Commit b346fbd

Browse files
committed
refactor: Make known registry return dataclass, not dict
1 parent 269aa50 commit b346fbd

File tree

8 files changed

+76
-27
lines changed

8 files changed

+76
-27
lines changed

src/ai/backend/manager/api/resource.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,11 @@ async def get_container_registries(request: web.Request) -> web.Response:
392392
GetContainerRegistriesAction()
393393
)
394394
)
395+
response: dict[str, str] = {}
396+
for registry in result.registries:
397+
response[f"{registry.project}/{registry.registry_name}"] = registry.url
395398

396-
return web.json_response(result.registries, status=HTTPStatus.OK)
399+
return web.json_response(response, status=HTTPStatus.OK)
397400

398401

399402
def create_app(

src/ai/backend/manager/data/container_registry/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,10 @@ def fields_to_update(self) -> dict[str, Any]:
7373
self.ssl_verify.update_dict(to_update, "ssl_verify")
7474
self.extra.update_dict(to_update, "extra")
7575
return to_update
76+
77+
78+
@dataclass
79+
class ContainerRegistryLocationInfo:
80+
project: Optional[str]
81+
registry_name: str
82+
url: str

src/ai/backend/manager/models/container_registry.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,16 @@ async def get_known_container_registries(
235235
cls,
236236
session: AsyncSession,
237237
) -> Mapping[str, Mapping[str, yarl.URL]]:
238-
query_stmt = sa.select(ContainerRegistryRow).options(
239-
load_only(
240-
ContainerRegistryRow.project,
241-
ContainerRegistryRow.registry_name,
242-
ContainerRegistryRow.url,
238+
query_stmt = (
239+
sa.select(ContainerRegistryRow)
240+
.options(
241+
load_only(
242+
ContainerRegistryRow.project,
243+
ContainerRegistryRow.registry_name,
244+
ContainerRegistryRow.url,
245+
)
243246
)
247+
.order_by(ContainerRegistryRow.registry_name, ContainerRegistryRow.project)
244248
)
245249
registries = cast(list[ContainerRegistryRow], (await session.scalars(query_stmt)).all())
246250
result: MutableMapping[str, MutableMapping[str, yarl.URL]] = {}

src/ai/backend/manager/repositories/container_registry/db_source/db_source.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ai.backend.manager.data.container_registry.types import (
1313
ContainerRegistryCreator,
1414
ContainerRegistryData,
15+
ContainerRegistryLocationInfo,
1516
ContainerRegistryModifier,
1617
)
1718
from ai.backend.manager.data.image.types import ImageStatus
@@ -197,18 +198,24 @@ async def mark_images_as_deleted(
197198

198199
await session.execute(update_stmt)
199200

200-
async def fetch_known_registries(self) -> dict[str, str]:
201+
async def fetch_known_registries(self) -> list[ContainerRegistryLocationInfo]:
201202
"""Fetch all known container registries from the database."""
202203
async with self._db.begin_readonly_session() as session:
203204
known_registries_map = await ContainerRegistryRow.get_known_container_registries(
204205
session
205206
)
206207

207-
known_registries = {}
208+
known_registries: list[ContainerRegistryLocationInfo] = []
208209
for project, registries in known_registries_map.items():
209210
for registry_name, url in registries.items():
210211
if project not in known_registries:
211-
known_registries[f"{project}/{registry_name}"] = url.human_repr()
212+
known_registries.append(
213+
ContainerRegistryLocationInfo(
214+
registry_name=registry_name,
215+
project=project,
216+
url=url.human_repr(),
217+
)
218+
)
212219

213220
return known_registries
214221

src/ai/backend/manager/repositories/container_registry/repository.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ai.backend.manager.data.container_registry.types import (
1010
ContainerRegistryCreator,
1111
ContainerRegistryData,
12+
ContainerRegistryLocationInfo,
1213
ContainerRegistryModifier,
1314
)
1415
from ai.backend.manager.models.container_registry import ContainerRegistryRow
@@ -91,7 +92,7 @@ async def clear_images(
9192
return await self._db_source.fetch_by_registry_and_project(registry_name, project)
9293

9394
@container_registry_repository_resilience.apply()
94-
async def get_known_registries(self) -> dict[str, str]:
95+
async def get_known_registries(self) -> list[ContainerRegistryLocationInfo]:
9596
return await self._db_source.fetch_known_registries()
9697

9798
@container_registry_repository_resilience.apply()

src/ai/backend/manager/services/container_registry/actions/get_container_registries.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from dataclasses import dataclass
2-
from typing import Any, Optional, override
2+
from typing import Optional, override
33

44
from ai.backend.manager.actions.action import BaseActionResult
5+
from ai.backend.manager.data.container_registry.types import (
6+
ContainerRegistryLocationInfo,
7+
)
58
from ai.backend.manager.services.container_registry.actions.base import ContainerRegistryAction
69

710

@@ -19,7 +22,7 @@ def operation_type(cls) -> str:
1922

2023
@dataclass
2124
class GetContainerRegistriesActionResult(BaseActionResult):
22-
registries: Any
25+
registries: list[ContainerRegistryLocationInfo]
2326

2427
@override
2528
def entity_id(self) -> Optional[str]:

tests/manager/integration/services/container_registry/test_container_registry_integration.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ async def test_get_container_registries_integration(
323323
# For global registry, actually set project to None in the database
324324
registry_id = None
325325
async with database_engine.begin_session() as session:
326-
registry = ContainerRegistryRow(
326+
registry_row = ContainerRegistryRow(
327327
url="https://registry3.example.com",
328328
registry_name="registry3",
329329
type=ContainerRegistryType.DOCKER,
@@ -334,9 +334,9 @@ async def test_get_container_registries_integration(
334334
is_global=True,
335335
extra=None,
336336
)
337-
session.add(registry)
337+
session.add(registry_row)
338338
await session.commit()
339-
registry_id = registry.id
339+
registry_id = registry_row.id
340340

341341
# Action: Get known registries
342342
action = GetContainerRegistriesAction()
@@ -345,12 +345,15 @@ async def test_get_container_registries_integration(
345345
)
346346

347347
# Verify: Should return mapping of project/registry to URL
348-
assert "projectA/registry1" in result.registries
349-
assert result.registries["projectA/registry1"] == "https://registry1.example.com/"
350-
assert "projectB/registry2" in result.registries
351-
assert result.registries["projectB/registry2"] == "https://registry2.example.com/"
348+
response = {}
349+
for registry in result.registries:
350+
response[f"{registry.project}/{registry.registry_name}"] = registry.url
351+
assert "projectA/registry1" in response
352+
assert response["projectA/registry1"] == "https://registry1.example.com/"
353+
assert "projectB/registry2" in response
354+
assert response["projectB/registry2"] == "https://registry2.example.com/"
352355
# Global registry should be included with None prefix
353-
assert "None/registry3" in result.registries
356+
assert "None/registry3" in response
354357

355358
# Cleanup the global registry manually
356359
if registry_id:

tests/manager/services/container_registry/actions/test_get_container_registries.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import pytest
22

3+
from ai.backend.manager.data.container_registry.types import (
4+
ContainerRegistryLocationInfo,
5+
)
36
from ai.backend.manager.services.container_registry.actions.get_container_registries import (
47
GetContainerRegistriesAction,
58
GetContainerRegistriesActionResult,
@@ -19,14 +22,32 @@
1922
"Success Case",
2023
GetContainerRegistriesAction(),
2124
GetContainerRegistriesActionResult(
22-
registries={
23-
# fixtures from fixtures.py
24-
"test_project/registry.example.com": "https://registry.example.com/",
25+
registries=[
2526
# fixtures from example-container-registries-harbor.json
26-
"community/cr.backend.ai": "https://cr.backend.ai/",
27-
"multiarch/cr.backend.ai": "https://cr.backend.ai/",
28-
"stable/cr.backend.ai": "https://cr.backend.ai/",
29-
}
27+
# Sorted by registry_name first (cr.backend.ai < registry.example.com),
28+
# then by project (community < multiarch < stable)
29+
ContainerRegistryLocationInfo(
30+
project="community",
31+
registry_name="cr.backend.ai",
32+
url="https://cr.backend.ai/",
33+
),
34+
ContainerRegistryLocationInfo(
35+
project="multiarch",
36+
registry_name="cr.backend.ai",
37+
url="https://cr.backend.ai/",
38+
),
39+
ContainerRegistryLocationInfo(
40+
project="stable",
41+
registry_name="cr.backend.ai",
42+
url="https://cr.backend.ai/",
43+
),
44+
# fixtures from fixtures.py
45+
ContainerRegistryLocationInfo(
46+
project="test_project",
47+
registry_name="registry.example.com",
48+
url="https://registry.example.com/",
49+
),
50+
],
3051
),
3152
),
3253
],

0 commit comments

Comments
 (0)