From 126be8cd7d0c5a2345286c937522b2902540536c Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 13 Jun 2025 13:41:30 +0200 Subject: [PATCH 01/11] feat: Add pilot management: create/delete/patch and query --- .../src/diracx/client/_generated/_client.py | 5 +- .../diracx/client/_generated/aio/_client.py | 5 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 586 +++++++++++++++ .../_generated/aio/operations/_patch.py | 2 + .../client/_generated/models/__init__.py | 8 + .../diracx/client/_generated/models/_enums.py | 13 + .../client/_generated/models/_models.py | 199 +++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 694 ++++++++++++++++++ .../client/_generated/operations/_patch.py | 2 + .../src/diracx/client/patches/pilots/aio.py | 53 ++ .../diracx/client/patches/pilots/common.py | 146 ++++ .../src/diracx/client/patches/pilots/sync.py | 53 ++ diracx-core/src/diracx/core/exceptions.py | 25 +- diracx-core/src/diracx/core/models.py | 38 +- diracx-db/src/diracx/db/sql/__init__.py | 2 +- diracx-db/src/diracx/db/sql/dummy/db.py | 8 +- diracx-db/src/diracx/db/sql/job/db.py | 3 +- .../src/diracx/db/sql/pilot_agents/db.py | 45 -- .../sql/{pilot_agents => pilots}/__init__.py | 0 diracx-db/src/diracx/db/sql/pilots/db.py | 245 +++++++ .../db/sql/{pilot_agents => pilots}/schema.py | 5 +- diracx-db/src/diracx/db/sql/utils/__init__.py | 18 +- .../pilot_agents/test_pilot_agents_db.py | 30 - .../{pilot_agents => pilots}/__init__.py | 0 .../tests/pilots/test_pilot_management.py | 196 +++++ diracx-db/tests/pilots/test_query.py | 300 ++++++++ diracx-db/tests/pilots/utils.py | 151 ++++ diracx-db/tests/test_dummy_db.py | 1 + .../src/diracx/logic/pilots/__init__.py | 0 .../src/diracx/logic/pilots/management.py | 122 +++ diracx-logic/src/diracx/logic/pilots/query.py | 191 +++++ diracx-routers/pyproject.toml | 2 + .../src/diracx/routers/pilots/__init__.py | 13 + .../diracx/routers/pilots/access_policies.py | 125 ++++ .../src/diracx/routers/pilots/management.py | 260 +++++++ .../src/diracx/routers/pilots/query.py | 165 +++++ .../tests/pilots/test_pilot_creation.py | 284 +++++++ diracx-routers/tests/pilots/test_query.py | 414 +++++++++++ docs/dev/explanations/pilots.md | 20 + .../src/gubbins/client/_generated/_client.py | 12 +- .../gubbins/client/_generated/aio/_client.py | 12 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 586 +++++++++++++++ .../client/_generated/models/__init__.py | 8 + .../client/_generated/models/_enums.py | 13 + .../client/_generated/models/_models.py | 199 +++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 694 ++++++++++++++++++ 50 files changed, 5861 insertions(+), 100 deletions(-) create mode 100644 diracx-client/src/diracx/client/patches/pilots/aio.py create mode 100644 diracx-client/src/diracx/client/patches/pilots/common.py create mode 100644 diracx-client/src/diracx/client/patches/pilots/sync.py delete mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/src/diracx/db/sql/pilots/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/schema.py (92%) delete mode 100644 diracx-db/tests/pilot_agents/test_pilot_agents_db.py rename diracx-db/tests/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/tests/pilots/test_pilot_management.py create mode 100644 diracx-db/tests/pilots/test_query.py create mode 100644 diracx-db/tests/pilots/utils.py create mode 100644 diracx-logic/src/diracx/logic/pilots/__init__.py create mode 100644 diracx-logic/src/diracx/logic/pilots/management.py create mode 100644 diracx-logic/src/diracx/logic/pilots/query.py create mode 100644 diracx-routers/src/diracx/routers/pilots/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilots/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/pilots/management.py create mode 100644 diracx-routers/src/diracx/routers/pilots/query.py create mode 100644 diracx-routers/tests/pilots/test_pilot_creation.py create mode 100644 diracx-routers/tests/pilots/test_query.py create mode 100644 docs/dev/explanations/pilots.md diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index aa558f636..9e37d5081 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 10cfad884..397b7f989 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index a87316174..82c27ee1b 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,6 +52,12 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2356,3 +2362,583 @@ async def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py index a408e57d2..0c70ce3e9 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations +from ....patches.pilots.aio import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 06de02aab..ae52349c3 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, HeartbeatData, @@ -27,6 +29,7 @@ JobStatusUpdate, Metadata, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -52,6 +55,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -67,6 +71,8 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -78,6 +84,7 @@ "JobStatusUpdate", "Metadata", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -100,6 +107,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/diracx-client/src/diracx/client/_generated/models/_enums.py b/diracx-client/src/diracx/client/_generated/models/_enums.py index 663d9c951..23edf99d3 100644 --- a/diracx-client/src/diracx/client/_generated/models/_enums.py +++ b/diracx-client/src/diracx/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index fc909fe5a..8763de15c 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -146,6 +146,109 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -886,6 +989,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index b8800ca84..c682d2a3a 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -586,6 +586,124 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2875,3 +2993,579 @@ def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any) -> L return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index b7b8c67fa..b14e98b84 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py new file mode 100644 index 000000000..ac533a67c --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator_async import distributed_trace_async + +from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace_async + async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().search(**make_search_body(**kwargs)) + + @distributed_trace_async + async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().summary(**make_summary_body(**kwargs)) + + @distributed_trace_async + async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace_async + async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py new file mode 100644 index 000000000..3f5ec8c4b --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -0,0 +1,146 @@ +"""Utilities which are common to the sync and async pilots operator patches.""" + +from __future__ import annotations + +__all__ = [ + "make_search_body", + "SearchKwargs", + "make_summary_body", + "SummaryKwargs", + "AddPilotStampsKwargs", + "make_add_pilot_stamps_body", + "UpdatePilotFieldsKwargs", + "make_update_pilot_fields_body" +] + +import json +from io import BytesIO +from typing import Any, IO, TypedDict, Unpack, cast, Literal + +from diracx.core.models import SearchSpec, PilotStatus, PilotFieldsMapping + + +class ResponseExtra(TypedDict, total=False): + content_type: str + headers: dict[str, str] + params: dict[str, str] + cls: Any + + +# ------------------ Search ------------------ +class SearchBody(TypedDict, total=False): + parameters: list[str] | None + search: list[SearchSpec] | None + sort: list[str] | None + + +class SearchExtra(ResponseExtra, total=False): + page: int + per_page: int + + +class SearchKwargs(SearchBody, SearchExtra): ... + + +class UnderlyingSearchArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: + body: SearchBody = {} + for key in SearchBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["parameters", "search", "sort"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSearchArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(SearchExtra, kwargs)) + return result + +# ------------------ Summary ------------------ + +class SummaryBody(TypedDict, total=False): + grouping: list[str] + search: list[str] + + +class SummaryKwargs(SummaryBody, ResponseExtra): ... + + +class UnderlyingSummaryArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: + body: SummaryBody = {} + for key in SummaryBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["grouping", "search"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ AddPilotStamps ------------------ + +class AddPilotStampsBody(TypedDict, total=False): + pilot_stamps: list[str] + grid_type: str + grid_site: str + pilot_references: dict[str, str] + pilot_status: PilotStatus + vo: str + +class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... + +class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs: + body: AddPilotStampsBody = {} + for key in AddPilotStampsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ UpdatePilotFields ------------------ + +class UpdatePilotFieldsBody(TypedDict, total=False): + pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + +class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ... + +class UnderlyingUpdatePilotFields(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields: + body: UpdatePilotFieldsBody = {} + for key in UpdatePilotFieldsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps_to_fields_mapping"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py new file mode 100644 index 000000000..744cee161 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator import distributed_trace + +from ..._generated.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace + def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().search(**make_search_body(**kwargs)) + + @distributed_trace + def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().summary(**make_summary_body(**kwargs)) + + @distributed_trace + def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace + def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 54d7c240d..19d8d5a41 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -15,6 +15,7 @@ class DiracError(RuntimeError): def __init__(self, detail: str = "Unknown"): self.detail = detail + super().__init__(detail) class AuthorizationError(DiracError): ... @@ -49,19 +50,19 @@ class InvalidQueryError(DiracError): class TokenNotFoundError(DiracError): - def __init__(self, jti: str, detail: str | None = None): + def __init__(self, jti: str, detail: str = ""): self.jti: str = jti super().__init__(f"Token {jti} not found" + (f" ({detail})" if detail else "")) class JobNotFoundError(DiracError): - def __init__(self, job_id: int, detail: str | None = None): + def __init__(self, job_id: int, detail: str = ""): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else "")) class SandboxNotFoundError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -71,7 +72,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyAssignedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -81,7 +82,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyInsertedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -91,7 +92,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class JobError(DiracError): - def __init__(self, job_id, detail: str | None = None): + def __init__(self, job_id, detail: str = ""): self.job_id: int = job_id super().__init__( f"Error concerning job {job_id}" + (f" ({detail})" if detail else "") @@ -100,3 +101,15 @@ def __init__(self, job_id, detail: str | None = None): class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet available.""" + + +class PilotNotFoundError(DiracError): + """At least one pilot is not found.""" + + +class PilotAlreadyExistsError(DiracError): + """At least one pilot already exists, we avoid collitions.""" + + +class PilotAlreadyAssociatedWithJobError(DiracError): + """We can't associate a pilot with the same job twice.""" diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index bacecd5ad..18144fc38 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import StrEnum -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -31,7 +31,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int + value: str | int | datetime class VectorSearchSpec(TypedDict): @@ -325,3 +325,37 @@ class JobCommand(BaseModel): job_id: int command: Literal["Kill"] arguments: str | None = None + + +class PilotFieldsMapping(BaseModel, extra="forbid"): + """All the fields that a user can modify on a Pilot (except PilotStamp).""" + + PilotStamp: str + StatusReason: Optional[str] = None + Status: Optional[PilotStatus] = None + BenchMark: Optional[float] = None + DestinationSite: Optional[str] = None + Queue: Optional[str] = None + GridSite: Optional[str] = None + GridType: Optional[str] = None + AccountingSent: Optional[bool] = None + CurrentJobID: Optional[int] = None + + +class PilotStatus(StrEnum): + #: The pilot has been generated and is transferred to a remote site: + SUBMITTED = "Submitted" + #: The pilot is waiting for a computing resource in a batch queue: + WAITING = "Waiting" + #: The pilot is running a payload on a worker node: + RUNNING = "Running" + #: The pilot finished its execution: + DONE = "Done" + #: The pilot execution failed: + FAILED = "Failed" + #: The pilot was deleted: + DELETED = "Deleted" + #: The pilot execution was aborted: + ABORTED = "Aborted" + #: Cannot get information about the pilot status: + UNKNOWN = "Unknown" diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index 3be3af8a3..e2f141ad5 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -12,6 +12,6 @@ from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB -from .pilot_agents.db import PilotAgentsDB +from .pilots.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 5735b43bb..12b0d719c 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -3,6 +3,7 @@ from sqlalchemy import insert from uuid_utils import UUID +from diracx.core.models import SearchSpec from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase @@ -21,8 +22,11 @@ class DummyDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = DummyDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: - return await self._summary(Cars, group_by, search) + async def summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=Cars, group_by=group_by, search=search) async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 01cdb83a1..40b39f33b 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -13,8 +13,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, _get_columns -from ..utils.functions import utcnow +from ..utils import BaseSQLDB, _get_columns, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py deleted file mode 100644 index 954f081b1..000000000 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone - -from sqlalchemy import insert - -from ..utils import BaseSQLDB -from .schema import PilotAgents, PilotAgentsDBBase - - -class PilotAgentsDB(BaseSQLDB): - """PilotAgentsDB class is a front-end to the PilotAgents Database.""" - - metadata = PilotAgentsDBBase.metadata - - async def add_pilot_references( - self, - pilot_ref: list[str], - vo: str, - grid_type: str = "DIRAC", - pilot_stamps: dict | None = None, - ) -> None: - if pilot_stamps is None: - pilot_stamps = {} - - now = datetime.now(tz=timezone.utc) - - # Prepare the list of dictionaries for bulk insertion - values = [ - { - "PilotJobReference": ref, - "VO": vo, - "GridType": grid_type, - "SubmissionTime": now, - "LastUpdateTime": now, - "Status": "Submitted", - "PilotStamp": pilot_stamps.get(ref, ""), - } - for ref in pilot_ref - ] - - # Insert multiple rows in a single execute call - stmt = insert(PilotAgents).values(values) - await self.conn.execute(stmt) - return diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilots/__init__.py similarity index 100% rename from diracx-db/src/diracx/db/sql/pilot_agents/__init__.py rename to diracx-db/src/diracx/db/sql/pilots/__init__.py diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py new file mode 100644 index 000000000..0bfb32e07 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import bindparam +from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import delete, insert, update + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, + PilotNotFoundError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + SearchSpec, + SortSpec, +) + +from ..utils import ( + BaseSQLDB, +) +from .schema import ( + JobToPilotMapping, + PilotAgents, + PilotAgentsDBBase, + PilotOutput, +) + + +class PilotAgentsDB(BaseSQLDB): + """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + + metadata = PilotAgentsDBBase.metadata + + # ----------------------------- Insert Functions ----------------------------- + + async def add_pilots( + self, + pilot_stamps: list[str], + vo: str, + grid_type: str = "DIRAC", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: dict[str, str] | None = None, + status: str = PilotStatus.SUBMITTED, + ): + """Bulk add pilots in the DB. + + If we can't find a pilot_reference associated with a stamp, we take the stamp by default. + """ + if pilot_references is None: + pilot_references = {} + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + values = [ + { + "PilotJobReference": pilot_references.get(stamp, stamp), + "VO": vo, + "GridType": grid_type, + "GridSite": grid_site, + "DestinationSite": destination_site, + "SubmissionTime": now, + "LastUpdateTime": now, + "Status": status, + "PilotStamp": stamp, + } + for stamp in pilot_stamps + ] + + # Insert multiple rows in a single execute call and use 'returning' to get primary keys + stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key + + await self.conn.execute(stmt) + + async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): + """Associate a pilot with jobs. + + job_to_pilot_mapping format: + ```py + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + ] + ``` + + Raises: + - PilotNotFoundError if a pilot_id is not associated with a pilot. + - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs. + - NotImplementedError if the integrity error is not caught. + + **Important note**: We assume that a job exists. + + """ + # Insert multiple rows in a single execute call + stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping) + + try: + await self.conn.execute(stmt) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise PilotNotFoundError( + detail="at least one of these pilots do not exist", + ) from e + + if ( + "duplicate entry" in str(e.orig).lower() + or "unique constraint" in str(e.orig).lower() + ): + raise PilotAlreadyAssociatedWithJobError( + detail="at least one of these pilots is already associated with a given job." + ) from e + + # Other errors to catch + raise NotImplementedError( + "Engine Specific error not caught" + str(e) + ) from e + + # ----------------------------- Delete Functions ----------------------------- + + async def delete_pilots(self, pilot_ids: list[int]): + """Destructive function. Delete pilots.""" + stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + async def remove_jobs_from_pilots(self, pilot_ids: list[int]): + """Destructive function. De-associate jobs and pilots.""" + stmt = delete(JobToPilotMapping).where( + JobToPilotMapping.pilot_id.in_(pilot_ids) + ) + + await self.conn.execute(stmt) + + async def delete_pilot_logs(self, pilot_ids: list[int]): + """Destructive function. Remove logs from pilots.""" + stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + # ----------------------------- Update Functions ----------------------------- + + async def update_pilot_fields( + self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + ): + """Bulk update pilots with a mapping. + + pilot_stamps_to_fields_mapping format: + ```py + [ + { + "PilotStamp": pilot_stamp, + "BenchMark": bench_mark, + "StatusReason": pilot_reason, + "AccountingSent": accounting_sent, + "Status": status, + "CurrentJobID": current_job_id, + "Queue": queue, + ... + } + ] + ``` + + The mapping helps to update multiple fields at a time. + + Raises PilotNotFoundError if one of the pilots is not found. + """ + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp")) + .values( + { + key: bindparam(key) + for key in pilot_stamps_to_fields_mapping[0] + .model_dump(exclude_none=True) + .keys() + if key != "PilotStamp" + } + ) + ) + + values = [ + { + **{"b_pilot_stamp": mapping.PilotStamp}, + **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True), + } + for mapping in pilot_stamps_to_fields_mapping + ] + + res = await self.conn.execute(stmt, values) + + if res.rowcount != len(pilot_stamps_to_fields_mapping): + raise PilotNotFoundError("at least one of the given pilot does not exist.") + + # ----------------------------- Search Functions ----------------------------- + + async def search_pilots( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for pilot information in the database.""" + return await self._search( + table=PilotAgents, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def search_pilot_to_job_mapping( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for jobs that are associated with pilots.""" + return await self._search( + table=JobToPilotMapping, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def pilot_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=PilotAgents, group_by=group_by, search=search) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py similarity index 92% rename from diracx-db/src/diracx/db/sql/pilot_agents/schema.py rename to diracx-db/src/diracx/db/sql/pilots/schema.py index bff7c460c..af087f1f8 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -10,6 +10,8 @@ ) from sqlalchemy.orm import declarative_base +from diracx.core.models import PilotStatus + from ..utils import Column, EnumBackedBool, NullColumn PilotAgentsDBBase = declarative_base() @@ -31,12 +33,13 @@ class PilotAgents(PilotAgentsDBBase): benchmark = Column("BenchMark", Double, default=0.0) submission_time = NullColumn("SubmissionTime", DateTime) last_update_time = NullColumn("LastUpdateTime", DateTime) - status = Column("Status", String(32), default="Unknown") + status = Column("Status", String(32), default=PilotStatus.UNKNOWN) status_reason = Column("StatusReason", String(255), default="Unknown") accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), + Index("PilotStamp", "PilotStamp"), Index("Status", "Status"), Index("Statuskey", "GridSite", "DestinationSite", "Status"), ) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 5cbb31b3f..53b3f3c96 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -7,21 +7,25 @@ apply_search_filters, apply_sort_constraints, ) -from .functions import hash, substract_date, utcnow +from .functions import ( + hash, + substract_date, + utcnow, +) from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( "_get_columns", - "utcnow", + "apply_search_filters", + "apply_sort_constraints", + "BaseSQLDB", "Column", - "NullColumn", "DateNowColumn", - "BaseSQLDB", "EnumBackedBool", "EnumColumn", - "apply_search_filters", - "apply_sort_constraints", - "substract_date", "hash", + "NullColumn", + "substract_date", "SQLDBUnavailableError", + "utcnow", ) diff --git a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py deleted file mode 100644 index 3ca989885..000000000 --- a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import pytest - -from diracx.db.sql.pilot_agents.db import PilotAgentsDB - - -@pytest.fixture -async def pilot_agents_db(tmp_path) -> PilotAgentsDB: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with agents_db.engine_context(): - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): - async with pilot_agents_db as pilot_agents_db: - # Add a pilot reference - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - stamp_dict = dict(zip(refs, stamps)) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict - ) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=None - ) diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilots/__init__.py similarity index 100% rename from diracx-db/tests/pilot_agents/__init__.py rename to diracx-db/tests/pilots/__init__.py diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py new file mode 100644 index 000000000..1e7397b39 --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .utils import ( + add_stamps, # noqa: F401 + create_old_pilots_environment, # noqa: F401 + create_timed_pilots, # noqa: F401 + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.mark.asyncio +async def test_insert_and_select(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Accept duplicates because it is checked by the logic + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None + ) + + +@pytest.mark.asyncio +async def test_insert_and_delete(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(2)] + stamps = [f"stamp_{i}" for i in range(2)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Works, the pilots exists + res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + # We delete the first pilot + await pilot_db.delete_pilots([res[0]["PilotID"]]) + + # We get the 2nd pilot that is not delete (no error) + await get_pilots_by_stamp(pilot_db, [stamps[1]]) + # We get the 1st pilot that is delete (error) + + assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + +@pytest.mark.asyncio +async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Assert values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 0.0 + assert pilot["Status"] == PilotStatus.SUBMITTED + assert pilot["StatusReason"] == "Unknown" + assert not pilot["AccountingSent"] + + # + # Modify a pilot, then check if every change is done + # + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ) + ] + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Set values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 1.0 + assert pilot["Status"] == PilotStatus.WAITING + assert pilot["StatusReason"] == "NewReason" + assert pilot["AccountingSent"] + + +@pytest.mark.asyncio +async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): + """We will proceed in few steps. + + 1. Create a pilot + 2. Verify that he is not associated with any job + 3. Associate with jobs + 4. Verify that he is associate with this job + 5. Associate with jobs that he already has and two that he has not + 6. Associate with jobs that he has not, but were involved in a crash + """ + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + # Add pilot + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + pilot_id = pilot["PilotID"] + + # Verify that he has no jobs + assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 + + now = datetime.now(tz=timezone.utc) + + # Associate pilot with jobs + pilot_jobs = [1, 2, 3] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Verify that he has all jobs + db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) + # We test both length and if every job is included if for any reason we have duplicates + assert all(job in db_jobs for job in pilot_jobs) + assert len(pilot_jobs) == len(db_jobs) + + # Associate pilot with a job that he already has, and one that he has not + pilot_jobs = [10, 1, 5] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Associate pilot with jobs that he has not, but was previously in an error + # To test that the rollback worked + pilot_jobs = [5, 10] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py new file mode 100644 index 000000000..be80f0179 --- /dev/null +++ b/diracx-db/tests/pilots/test_query.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_db(pilot_db): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i + 1}" for i in range(N)] + stamps = [f"stamp_{i + 1}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ) + for i, pilot_stamp in enumerate(stamps) + ] + ) + + yield pilot_db + + +async def test_search_parameters(populated_pilot_db): + """Test that we can search specific parameters for pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific parameter: PilotID + total, result = await pilot_db.search_pilots(["PilotID"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + + # Search a specific parameter: Status + total, result = await pilot_db.search_pilots(["Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"Status"} + + # Search for multiple parameters: PilotID, Status + total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + + # Search for a specific parameter but use distinct: Status + total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) + assert total == len(PILOT_STATUSES) + assert result + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + total, result = await pilot_db.search_pilots(["Dummy"], [], []) + + +async def test_search_conditions(populated_pilot_db): + """Test that we can search for specific pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert not result + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 0 + assert not result + + +async def test_search_sorts(populated_pilot_db): + """Test that we can search for pilots in the database and sort the results.""" + async with populated_pilot_db as pilot_db: + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) + assert total == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + + +@pytest.mark.parametrize( + "per_page, page, expected_len, expected_first_id, expect_exception", + [ + (10, 1, 10, 1, None), # Page 1 + (10, 2, 10, 11, None), # Page 2 + (10, 10, 10, 91, None), # Page 10 + (50, 2, 50, 51, None), # Page 2 with 50 per page + (10, 11, 0, None, None), # Page beyond range, should return empty + (10, 0, None, None, InvalidQueryError), # Invalid page + (0, 1, None, None, InvalidQueryError), # Invalid per_page + ], +) +async def test_search_pagination( + populated_pilot_db, + per_page, + page, + expected_len, + expected_first_id, + expect_exception, +): + """Test pagination logic in pilot search.""" + async with populated_pilot_db as pilot_db: + if expect_exception: + with pytest.raises(expect_exception): + await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) + else: + total, result = await pilot_db.search_pilots( + [], [], [], per_page=per_page, page=page + ) + assert total == N + if expected_len == 0: + assert not result + else: + assert result + assert len(result) == expected_len + assert result[0]["PilotID"] == expected_first_id diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py new file mode 100644 index 000000000..793310d0d --- /dev/null +++ b/diracx-db/tests/pilots/utils.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + +# ------------ Fetching data ------------ + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return pilots + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +# ------------ Creating data ------------ + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + return await get_pilots_by_stamp(db, stamps) + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await get_pilots_by_stamp(db, pilot_stamps) + return pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index f94eda5b7..8e324a28e 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -149,6 +149,7 @@ async def test_failed_transaction(dummy_db): assert result # This will raise an exception and the transaction will be rolled back + result = await dummy_db.summary(["unexistingfieldraisinganerror"], []) assert result[0]["count"] == 10 diff --git a/diracx-logic/src/diracx/logic/pilots/__init__.py b/diracx-logic/src/diracx/logic/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py new file mode 100644 index 000000000..3c8d1251b --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.models import PilotFieldsMapping +from diracx.db.sql import PilotAgentsDB + +from .query import ( + get_outdated_pilots, + get_pilot_ids_by_stamps, + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + + +async def register_new_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + vo: str, + grid_type: str, + grid_site: str, + destination_site: str, + status: str, + pilot_job_references: dict[str, str] | None, +): + # [IMPORTANT] Check unicity of pilot stamps + # If a pilot already exists, we raise an error (transaction will rollback) + existing_pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) + + # If we found pilots from the list, this means some pilots already exists + if len(existing_pilots) > 0: + found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} + + raise PilotAlreadyExistsError( + f"The following pilots already exist: {found_keys}" + ) + + await pilot_db.add_pilots( + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_references=pilot_job_references, + status=status, + ) + + +async def delete_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str] | None = None, + age_in_days: int | None = None, + delete_only_aborted: bool = True, + vo_constraint: str | None = None, +): + if pilot_stamps: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True + ) + else: + assert age_in_days + assert vo_constraint + + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + pilots = await get_outdated_pilots( + pilot_db=pilot_db, + cutoff_date=cutoff_date, + only_aborted=delete_only_aborted, + parameters=["PilotID"], + vo_constraint=vo_constraint, + ) + + pilot_ids = [pilot["PilotID"] for pilot in pilots] + + await pilot_db.remove_jobs_from_pilots(pilot_ids) + await pilot_db.delete_pilot_logs(pilot_ids) + await pilot_db.delete_pilots(pilot_ids) + + +async def update_pilots_fields( + pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +): + await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) + + +async def add_jobs_to_pilot( + pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int] +): + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids + ] + + await pilot_db.add_jobs_to_pilot( + job_to_pilot_mapping=job_to_pilot_mapping, + ) + + +async def get_pilot_jobs_ids_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamp: str +) -> list[int]: + """Fetch pilot jobs by stamp.""" + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + except PilotNotFoundError: + return [] + + return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py new file mode 100644 index 000000000..b6cf504d7 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import ( + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SearchParams, + SearchSpec, + SummaryParams, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql import PilotAgentsDB + +MAX_PER_PAGE = 10000 + + +async def search( + pilot_db: PilotAgentsDB, + user_vo: str, + page: int = 1, + per_page: int = 100, + body: SearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + body.search.append( + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo + ) + ) + + total, pilots = await pilot_db.search_pilots( + body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + return total, pilots + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] = [], + allow_missing: bool = True, +) -> list[dict[Any, Any]]: + """Get pilots by their stamp. + + If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. + """ + if parameters: + parameters.append("PilotStamp") + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # allow_missing is set as True by default to mark explicitly when we allow or not + if not allow_missing: + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + detail=str(missing), + ) + + return pilots + + +async def get_pilot_ids_by_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False +) -> list[int]: + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["PilotID"], + allow_missing=allow_missing, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [job["JobID"] for job in jobs] + + +async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: + _, pilots = await pilot_db.search_pilot_to_job_mapping( + parameters=["PilotID"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_outdated_pilots( + pilot_db: PilotAgentsDB, + cutoff_date: datetime, + vo_constraint: str, + only_aborted: bool = True, + parameters: list[str] = [], +): + query: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff_date, + ), + # Add VO to avoid deleting other VO's pilots + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint + ), + ] + + if only_aborted: + query.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, search=query, sorts=[] + ) + + return pilots + + +async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): + """Show information suitable for plotting.""" + body.search.append( + { + "parameter": "VO", + "operator": ScalarSearchOperator.EQUAL, + "value": vo, + } + ) + return await pilot_db.pilot_summary(body.grouping, body.search) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 6f554c74e..2038223ce 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -46,10 +46,12 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +pilots = "diracx.routers.pilots:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +PilotManagementAccessPolicy = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..03f9b8422 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .management import router as management_router +from .query import router as query_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter() +router.include_router(management_router) +router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..61a324f79 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.models import VectorSearchOperator, VectorSearchSpec +from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + # Change some pilot fields + MANAGE_PILOTS = auto() + # Read some pilot info + READ_PILOT_FIELDS = auto() + + +class PilotManagementAccessPolicy(BaseAccessPolicy): + """Rules: + * Every user can access data about his VO + * An administrator can modify a pilot. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, + allow_legacy_pilots: bool = False, + ): + assert action, "action is a mandatory parameter" + + # Users can query + # NOTE: Add into queries a VO constraint + # To manage pilots, user have to be an admin + # In some special cases (described with allow_legacy_pilots), we can allow pilots + if action == ActionType.MANAGE_PILOTS: + # To make it clear, we separate + is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) + + if not is_an_admin and not is_a_pilot_if_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) + + if action == ActionType.READ_PILOT_FIELDS: + if GENERIC_PILOT in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can't read other pilots info.", + ) + + # + # Additional checks if job_ids or pilot_stamps are provided + # + + # First, if job_ids are provided, we check who is the owner + if job_db and job_ids: + job_owners = await job_db.summary( + ["Owner", "VO"], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ], + ) + + expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if not job_owners == [expected_owner]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to modify a pilot.", + ) + + # This is for example when we submit pilots, we use the user VO, so no need to verify + if pilot_db and pilot_stamps: + # Else, check its VO + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], + allow_missing=True, + ) + + if len(pilots) != len(pilot_stamps): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot does not exist.", + ) + + if not all(pilot["VO"] == user_info.vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to all pilots.", + ) + + +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py new file mode 100644 index 000000000..21ff63796 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends, HTTPException, Query, status + +from diracx.core.exceptions import ( + PilotAlreadyExistsError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.core.properties import GENERIC_PILOT +from diracx.logic.pilots.management import ( + delete_pilots as delete_pilots_bl, +) +from diracx.logic.pilots.management import ( + get_pilot_jobs_ids_by_stamp, + register_new_pilots, + update_pilots_fields, +) +from diracx.logic.pilots.query import get_pilot_ids_by_job_id +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..dependencies import JobDB, PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + + +@router.post("/") +async def add_pilot_stamps( + pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str], + Body(description="List of the pilot stamps we want to add to the db."), + ], + vo: Annotated[str, Body(description="Pilot virtual organization.")], + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", + destination_site: Annotated[ + str, Body(description="Pilots destination site.") + ] = "NotAssigned", + pilot_references: Annotated[ + dict[str, str] | None, + Body(description="Association of a pilot reference with a pilot stamp."), + ] = None, + pilot_status: Annotated[ + PilotStatus, Body(description="Status of the pilots.") + ] = PilotStatus.SUBMITTED, +): + """Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + """ + # TODO: Verify that grid types, sites, destination sites, etc. are valids + await check_permissions( + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to create thousands of pilots at a time + # (It would be still able to create thousands of pilots, but slower) + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only create yourself.", + ) + + try: + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_job_references=pilot_references, + status=pilot_status, + ) + except PilotAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/", status_code=HTTPStatus.NO_CONTENT) +async def delete_pilots( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + pilot_stamps: Annotated[ + list[str] | None, Query(description="Stamps of the pilots we want to delete.") + ] = None, + age_in_days: Annotated[ + int | None, + Query( + description=( + "The number of days that define the maximum age of pilots to be deleted." + "Pilots older than this age will be considered for deletion." + ) + ), + ] = None, + delete_only_aborted: Annotated[ + bool, + Query( + description=( + "Flag indicating whether to only delete pilots whose status is 'Aborted'." + "If set to True, only pilots with the 'Aborted' status will be deleted." + "It is set by default as True to avoid any mistake." + "This flag is only used for deletion by time." + ) + ), + ] = False, +): + """Endpoint to delete a pilot. + + Two features: + + 1. Or you provide pilot_stamps, so you can delete pilots by their stamp + 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + """ + vo_constraint: str | None = None + + # If we delete by pilot_stamps, we check that we can access them + # Else, we add a constraint to the request, to avoid deleting pilots from another VO + if pilot_stamps: + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + else: + vo_constraint = user_info.vo + + if not pilot_stamps and not age_in_days: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="pilot_stamps or age_in_days have to be provided.", + ) + + await delete_pilots_bl( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + vo_constraint=vo_constraint, + ) + + +EXAMPLE_UPDATE_FIELDS = { + "Update the BenchMark field": { + "summary": "Update BenchMark", + "description": "Update only the BenchMark for one pilot.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} + ] + }, + }, + "Update multiple statuses": { + "summary": "Update multiple pilots", + "description": "Update multiple pilots statuses.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, + {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + ] + }, + }, +} + + +@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) +async def update_pilot_fields( + pilot_stamps_to_fields_mapping: Annotated[ + list[PilotFieldsMapping], + Body( + description="(pilot_stamp, pilot_fields) mapping to change.", + embed=True, + openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore + ), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + """ + # Ensures stamps validity + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time + # (It would be still able to modify thousands of pilots, but slower) + # We are not able to affirm that this pilot modifies itself + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only modify yourself.", + ) + + await update_pilots_fields( + pilot_db=pilot_db, + pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, + ) + + +@router.get("/jobs") +async def get_pilot_jobs( + pilot_db: PilotAgentsDB, + job_db: JobDB, + check_permissions: CheckPilotManagementPolicyCallable, + pilot_stamp: Annotated[ + str | None, Query(description="The stamp of the pilot.") + ] = None, + job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, +) -> list[int]: + """Endpoint only for admins, to get jobs of a pilot.""" + if pilot_stamp: + # Check VO + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + ) + + return await get_pilot_jobs_ids_by_stamp( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + ) + elif job_id: + # Check job owner + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + ) + + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You must provide either pilot_stamp or job_id", + ) diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py new file mode 100644 index 000000000..56655b46c --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import Body, Depends, Response + +from diracx.core.models import SearchParams, SummaryParams +from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import summary as summary_bl + +from ..dependencies import PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +EXAMPLE_SEARCHES = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get ordered pilot statuses": { + "summary": "Get ordered pilot statuses", + "description": "Get only pilot statuses for specific pilots, ordered by status", + "value": { + "parameters": ["PilotID", "Status"], + "search": [ + {"parameter": "PilotID", "operator": "in", "values": ["6", "2", "3"]} + ], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of pilots returned in this response", + "schema": {"type": "string", "example": "pilots 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, +} + + +@router.post("/search", responses=EXAMPLE_RESPONSES) +async def search( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore + ] = None, +) -> list[dict[str, Any]]: + """Retrieve information about pilots.""" + # Inspired by /api/jobs/query + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + total, pilots = await search_bl( + pilot_db=pilot_db, + user_vo=user_info.vo, + page=page, + per_page=per_page, + body=body, + ) + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No pilots found but there are pilots for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(pilots) == 0 and total > 0: + response.headers["Content-Range"] = f"pilots */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of pilots is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(pilots) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return pilots + + +@router.post("/summary") +async def summary( + pilot_db: PilotAgentsDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + body: SummaryParams, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Show information suitable for plotting.""" + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + return await summary_bl( + pilot_db=pilot_db, + body=body, + vo=user_info.vo, + ) diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py new file mode 100644 index 000000000..c055727c9 --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.db.sql import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +pytestmark = pytest.mark.enabled_dependencies( + [ + "PilotCredentialsAccessPolicy", + "DevelopmentSettings", + "AuthDB", + "AuthSettings", + "ConfigSource", + "BaseAccessPolicy", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + "JobDB", + ] +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_create_pilots(normal_test_client): + # Lots of request, to validate that it returns the credentials in the same order as the input references + pilot_stamps = [f"stamps_{i}" for i in range(N)] + + # -------------- Bulk insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Register a pilot that already exists, and one that does not -------------- + + body = { + "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], + "vo": MAIN_VO, + } + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 409 + assert ( + r.json()["detail"] + == f"The following pilots already exist: {{'{pilot_stamps[0]}'}}" + ) + + # -------------- Register a pilot that does not exists **but** was called before in an error -------------- + # To prove that, if I tried to register a pilot that does not exist with one that already exists, + # i can normally add the one that did not exist before (it should not have added it before) + body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 200 + + +async def test_create_pilot_and_delete_it(normal_test_client): + pilot_stamp = "stamps_1" + + # -------------- Insert -------------- + body = {"pilot_stamps": [pilot_stamp], "vo": MAIN_VO} + + # Create a pilot + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Duplicate -------------- + # Duplicate because it exists, should have 409 + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 409, r.json() + + # -------------- Delete -------------- + params = {"pilot_stamps": [pilot_stamp]} + + # We delete the pilot + r = normal_test_client.delete( + "/api/pilots/", + params=params, + ) + + assert r.status_code == 204 + + # -------------- Insert -------------- + # Create a the same pilot, but works because it does not exist anymore + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + +async def test_create_pilot_and_modify_it(normal_test_client): + pilot_stamps = ["stamps_1", "stamp_2"] + + # -------------- Insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + # Create pilots + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Modify -------------- + # We modify only the first pilot + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamps[0], + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ).model_dump(exclude_unset=True) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + body = { + "parameters": [], + "search": [], + "sort": [], + "distinct": True, + } + + r = normal_test_client.post("/api/pilots/search", json=body) + assert r.status_code == 200, r.json() + pilot1 = r.json()[0] + pilot2 = r.json()[1] + + assert pilot1["BenchMark"] == 1.0 + assert pilot1["StatusReason"] == "NewReason" + assert pilot1["AccountingSent"] + assert pilot1["Status"] == PilotStatus.WAITING + + assert pilot2["BenchMark"] != pilot1["BenchMark"] + assert pilot2["StatusReason"] != pilot1["StatusReason"] + assert pilot2["AccountingSent"] != pilot1["AccountingSent"] + assert pilot2["Status"] != pilot1["Status"] + + +@pytest.mark.asyncio +async def test_delete_pilots_by_age_and_stamp(normal_test_client): + # Generate 100 pilot stamps + pilot_stamps = [f"stamp_{i}" for i in range(100)] + + # -------------- Insert all pilots -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + r = normal_test_client.post("/api/pilots/", json=body) + assert r.status_code == 200, r.json() + + # -------------- Modify last 50 pilots' fields -------------- + to_modify = pilot_stamps[50:] + mappings = [] + for idx, stamp in enumerate(to_modify): + # First 25 of modified set to ABORTED, others to WAITING + status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING + mapping = PilotFieldsMapping( + PilotStamp=stamp, + BenchMark=idx + 0.1, + StatusReason=f"Reason_{idx}", + AccountingSent=(idx % 2 == 0), + Status=status, + ).model_dump(exclude_unset=True) + mappings.append(mapping) + + r = normal_test_client.patch( + "/api/pilots/metadata", + json={"pilot_stamps_to_fields_mapping": mappings}, + ) + assert r.status_code == 204 + + # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 -------------- + old_date = datetime(2003, 3, 14, tzinfo=timezone.utc) + # Access DB session from normal_test_client fixtures + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(to_modify)) + .values(SubmissionTime=old_date) + ) + await db.conn.execute(stmt) + await db.conn.commit() + + # -------------- Verify all 100 pilots exist -------------- + search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200, r.json() + assert len(r.json()) == 100 + + # -------------- 1) Delete only old aborted pilots (25 expected) -------------- + # age_in_days large enough to include 2003-03-14 + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15, "delete_only_aborted": True}, + ) + assert r.status_code == 204 + # Expect 75 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 75 + + # -------------- 2) Delete all old pilots (remaining 25 old) -------------- + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15}, + ) + assert r.status_code == 204 + + # Expect 50 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 50 + + # -------------- 3) Delete one recent pilot by stamp -------------- + one_stamp = pilot_stamps[10] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) + assert r.status_code == 204 + # Expect 49 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 49 + + # -------------- 4) Delete all remaining pilots -------------- + # Collect remaining stamps + remaining = [p["PilotStamp"] for p in r.json()] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) + assert r.status_code == 204 + # Expect none remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200 + assert len(r.json()) == 0 + + # -------------- 5) Attempt deleting unknown pilot, expect 400 -------------- + r = normal_test_client.delete( + "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} + ) + assert r.status_code == 204 diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py new file mode 100644 index 000000000..c6d5cedb4 --- /dev/null +++ b/diracx-routers/tests/pilots/test_query.py @@ -0,0 +1,414 @@ +"""Inspired by pilots and jobs db search tests.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ConfigSource", + "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +MAIN_VO = "lhcb" +N = 100 + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_client(normal_test_client): + pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, pilot_stamp in enumerate(pilot_stamps) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + yield normal_test_client + + +async def test_pilot_summary(populated_pilot_client: TestClient): + # Group by StatusReason + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["StatusReason"], + }, + ) + + assert r.status_code == 200 + + assert sum([el["count"] for el in r.json()]) == N + assert len(r.json()) == len(PILOT_REASONS) + + # Group by CurrentJobID + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + }, + ) + + assert r.status_code == 200 + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == N + + # Group by CurrentJobID where BenchMark < 10^2 + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}], + }, + ) + + assert r.status_code == 200, r.json() + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == 10 + + +@pytest.fixture +async def search(populated_pilot_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_search_parameters(search): + """Test that we can search specific parameters for pilots.""" + # Search a specific parameter: PilotID + result, headers = await search(["PilotID"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + assert "Content-Range" not in headers + + # Search a specific parameter: Status + result, headers = await search(["Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"Status"} + assert "Content-Range" not in headers + + # Search for multiple parameters: PilotID, Status + result, headers = await search(["PilotID", "Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + assert "Content-Range" not in headers + + # Search for a specific parameter but use distinct: Status + result, headers = await search(["Status"], [], [], distinct=True) + assert len(result) == len(PILOT_STATUSES) + assert result + assert "Content-Range" not in headers + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + result, headers = await search(["Dummy"], [], []) + + +async def test_search_conditions(search): + """Test that we can search for specific pilots.""" + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + result, headers = await search([], [condition], []) + assert not result + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + result, headers = await search([], [condition], []) + assert len(result) == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + assert "Content-Range" not in headers + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + assert len(result) == 0 + assert not result + assert "Content-Range" not in headers + + +async def test_search_sorts(search): + """Test that we can search for pilots and sort the results.""" + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + assert "Content-Range" not in headers + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort1, sort2]) + assert len(result) == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + assert "Content-Range" not in headers + + +async def test_search_pagination(search): + """Test that we can search for pilots.""" + # Search for the first 10 pilots + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 pilots + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 pilots + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 pilots + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/docs/dev/explanations/pilots.md b/docs/dev/explanations/pilots.md new file mode 100644 index 000000000..544e12db5 --- /dev/null +++ b/docs/dev/explanations/pilots.md @@ -0,0 +1,20 @@ +## Presentation + +Pilots are a piece of software that is running on *worker nodes*. There are two types of pilots: "DIRAC pilots", and "DiracX pilots". The first type corresponds to pilots with proxies, sent by DIRAC; and the second type corresponds to pilots with secrets. Both kinds will eventually interact with DiracX using tokens (DIRAC pilots by exchanging their proxies for tokens, DiracX by exchanging their secrets for tokens). + +## Management + +Their management is adapted in DiracX, and each feature has its own route in DiracX. We will split the `/pilots` route into two parts: + +1. `/api/pilots/*` to allow administrators and users to access and modify pilots +2. `/api/pilots/internal/*` is allocated for pilots resources: only DiracX pilots will have access to these resources + +Each part has its own security policy: we want to prevent pilots to access users resources and vice-versa. To differentiate DIRAC pilots from users, we can get their token and compare their properties: `GENERIC_PILOT` is the property that defines a pilot. For DiracX pilots, we can differentiate them by looking at the token structure: they don't have properties, but a "stamp" (their identifier). + +## Endpoints + +We ordered our endpoints like so: + +1. Creation: `POST /api/pilots/` +2. Deletion: `DELETE /api/pilots/` +3. Modification: `PATCH /api/pilots/metadata` diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index 65282efb6..fdf17b6a3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index d67986dae..76280797e 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.aio.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 8927a2921..19925b650 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,6 +55,12 @@ build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2523,3 +2529,583 @@ async def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index d8e29cfeb..7bdd59b63 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -27,6 +29,7 @@ JobMetaDataAccountedFlag, JobStatusUpdate, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -52,6 +55,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -67,6 +71,8 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -78,6 +84,7 @@ "JobMetaDataAccountedFlag", "JobStatusUpdate", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -100,6 +107,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py index 663d9c951..23edf99d3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index faaea49b4..2e8717cb6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -146,6 +146,109 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -907,6 +1010,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index fa5e665ce..4358ecf51 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -635,6 +635,124 @@ def build_lollygag_get_gubbins_secrets_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -3088,3 +3206,579 @@ def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore From 3b9f72809f972bdc5d8f48cb4c4b6f7e5d59a152 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 09:23:56 +0200 Subject: [PATCH 02/11] feat: Add custom Dirac branch for integration tests --- .github/workflows/integration.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e065fad11..28e3d8a05 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: dirac-branch: - - integration + - robin-migrate-client steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -37,7 +37,7 @@ jobs: - name: Clone DIRAC run: | pip install typer pyyaml gitpython packaging - git clone https://github.com/DIRACGrid/DIRAC.git -b "${{ matrix.dirac-branch }}" /tmp/DIRACRepo + git clone https://github.com/Robin-Van-de-Merghel/DIRAC.git -b "${{ matrix.dirac-branch }}" /tmp/DIRACRepo echo "Current revision: $(git -C /tmp/DIRACRepo rev-parse HEAD)" # We need to cd in the directory for the integration_tests.py to work - name: Prepare environment From 5fdfaa66e45eda68ed47befcbd8f66d6fad03789 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 10:04:48 +0200 Subject: [PATCH 03/11] fix: Add more security to the pilot creation router --- .../src/diracx/routers/pilots/management.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 21ff63796..a383643d1 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -12,7 +12,7 @@ PilotFieldsMapping, PilotStatus, ) -from diracx.core.properties import GENERIC_PILOT +from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR from diracx.logic.pilots.management import ( delete_pilots as delete_pilots_bl, ) @@ -72,10 +72,17 @@ async def add_pilot_stamps( if GENERIC_PILOT in user_info.properties: if len(pilot_stamps) != 1: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_403_FORBIDDEN, detail="As a pilot, you can only create yourself.", ) + if JOB_ADMINISTRATOR not in user_info.properties: + if not vo == user_info.vo: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You can create pilots only for your VO.", + ) + try: await register_new_pilots( pilot_db=pilot_db, From 7c6ba77b5045e0a2d6c2f9119d3f0e534f297c97 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 13:33:43 +0200 Subject: [PATCH 04/11] fix: Fixed a micro typo --- diracx-logic/src/diracx/logic/pilots/management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 3c8d1251b..a6256a742 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -30,7 +30,7 @@ async def register_new_pilots( pilot_db=pilot_db, pilot_stamps=pilot_stamps ) - # If we found pilots from the list, this means some pilots already exists + # If we found pilots from the list, this means some pilots already exist if len(existing_pilots) > 0: found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} From 02b89c0e7a53b501e1b5eb4e57630c5b77570280 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 13 Jun 2025 13:41:30 +0200 Subject: [PATCH 05/11] feat: Add pilot registration (secret-exchange) --- .../src/diracx/client/patches/pilots/aio.py | 3 +- .../diracx/client/patches/pilots/common.py | 4 +- .../src/diracx/client/patches/pilots/sync.py | 3 +- diracx-core/src/diracx/core/exceptions.py | 15 + diracx-core/src/diracx/core/models.py | 94 +++- diracx-core/src/diracx/core/settings.py | 2 + diracx-core/src/diracx/core/utils.py | 32 +- diracx-db/src/diracx/db/sql/pilots/db.py | 155 +++++- diracx-db/src/diracx/db/sql/pilots/schema.py | 30 ++ diracx-db/src/diracx/db/sql/utils/__init__.py | 7 +- .../src/diracx/db/sql/utils/functions.py | 4 + diracx-db/tests/pilots/test_pilot_auth.py | 166 +++++++ diracx-db/tests/pilots/utils.py | 173 ++++++- diracx-logic/src/diracx/logic/auth/token.py | 18 +- diracx-logic/src/diracx/logic/auth/utils.py | 16 +- diracx-logic/src/diracx/logic/pilots/auth.py | 413 ++++++++++++++++ .../src/diracx/logic/pilots/management.py | 44 +- diracx-logic/src/diracx/logic/pilots/query.py | 65 ++- diracx-routers/pyproject.toml | 1 + .../src/diracx/routers/auth/__init__.py | 2 + .../src/diracx/routers/auth/pilots.py | 107 ++++ .../src/diracx/routers/auth/token.py | 3 +- .../routers/pilot_resources/__init__.py | 23 + .../diracx/routers/pilot_resources/util.py | 32 ++ .../src/diracx/routers/pilots/__init__.py | 2 +- .../src/diracx/routers/pilots/management.py | 103 +++- .../src/diracx/routers/utils/pilots.py | 73 +++ .../src/diracx/routers/utils/users.py | 56 +-- diracx-routers/tests/auth/test_standard.py | 1 + .../tests/pilots/test_pilot_auth.py | 456 ++++++++++++++++++ .../tests/pilots/test_pilot_creation.py | 165 +++++++ diracx-testing/src/diracx/testing/utils.py | 21 + 32 files changed, 2213 insertions(+), 76 deletions(-) create mode 100644 diracx-db/tests/pilots/test_pilot_auth.py create mode 100644 diracx-logic/src/diracx/logic/pilots/auth.py create mode 100644 diracx-routers/src/diracx/routers/auth/pilots.py create mode 100644 diracx-routers/src/diracx/routers/pilot_resources/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilot_resources/util.py create mode 100644 diracx-routers/src/diracx/routers/utils/pilots.py create mode 100644 diracx-routers/tests/pilots/test_pilot_auth.py diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py index ac533a67c..56d278a1f 100644 --- a/diracx-client/src/diracx/client/patches/pilots/aio.py +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -16,6 +16,7 @@ from azure.core.tracing.decorator_async import distributed_trace_async from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from ..._generated.models._models import PilotCredentialsInfo from .common import ( make_search_body, make_summary_body, @@ -43,7 +44,7 @@ async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]] return await super().summary(**make_summary_body(**kwargs)) @distributed_trace_async - async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list[PilotCredentialsInfo] | None: """TODO""" return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py index 3f5ec8c4b..258bc42f8 100644 --- a/diracx-client/src/diracx/client/patches/pilots/common.py +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -99,6 +99,8 @@ class AddPilotStampsBody(TypedDict, total=False): pilot_references: dict[str, str] pilot_status: PilotStatus vo: str + generate_secrets: bool + pilot_secret_use_count_max: int | None class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... @@ -112,7 +114,7 @@ def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> Underl for key in AddPilotStampsBody.__optional_keys__: if key not in kwargs: continue - key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo", "generate_secrets", "pilot_secret_use_count_max"], key) value = kwargs.pop(key) if value is not None: body[key] = value diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py index 744cee161..e3059013b 100644 --- a/diracx-client/src/diracx/client/patches/pilots/sync.py +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -16,6 +16,7 @@ from azure.core.tracing.decorator import distributed_trace from ..._generated.operations._operations import PilotsOperations as _PilotsOperations +from ..._generated.models._models import PilotCredentialsInfo from .common import ( make_search_body, make_summary_body, @@ -43,7 +44,7 @@ def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: return super().summary(**make_summary_body(**kwargs)) @distributed_trace - def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list[PilotCredentialsInfo] | None: """TODO""" return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 19d8d5a41..f5af64322 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -99,6 +99,9 @@ def __init__(self, job_id, detail: str = ""): ) +class BadTokenError(DiracError): ... + + class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet available.""" @@ -113,3 +116,15 @@ class PilotAlreadyExistsError(DiracError): class PilotAlreadyAssociatedWithJobError(DiracError): """We can't associate a pilot with the same job twice.""" + + +class SecretHasExpiredError(DiracError): + """If a secret expired.""" + + +class SecretNotFoundError(DiracError): + """If a secret not found.""" + + +class BadPilotCredentialsError(DiracError): + """If a pilot tries to auth with another pilot's credentials.""" diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 18144fc38..8f367550a 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -5,12 +5,15 @@ from __future__ import annotations +import uuid as std_uuid from datetime import datetime -from enum import StrEnum -from typing import Literal, Optional +from enum import StrEnum, auto +from typing import Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import CoreSchema, core_schema from typing_extensions import TypedDict +from uuid_utils import UUID as _UUID class ScalarSearchOperator(StrEnum): @@ -37,7 +40,7 @@ class ScalarSearchSpec(TypedDict): class VectorSearchSpec(TypedDict): parameter: str operator: VectorSearchOperator - values: list[str] | list[int] + values: list[str] | list[int] | list[bytes] SearchSpec = ScalarSearchSpec | VectorSearchSpec @@ -230,6 +233,29 @@ class SandboxUploadResponse(BaseModel): fields: dict[str, str] = {} +class UUID(_UUID): + """Subclass of uuid_utils.UUID to add pydantic support.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Use the stdlib uuid.UUID schema for validation and serialization.""" + std_schema = handler(std_uuid.UUID) + + def to_uuid_utils(u: std_uuid.UUID) -> UUID: + return cls(str(u)) + + return core_schema.no_info_after_validator_function(to_uuid_utils, std_schema) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """Return the stdlib uuid.UUID schema for JSON serialization.""" + return handler(core_schema) + + class GrantType(StrEnum): """Grant types for OAuth2.""" @@ -264,9 +290,14 @@ class OpenIDConfiguration(TypedDict): code_challenge_methods_supported: list[str] -class TokenPayload(TypedDict): +class BaseTokenPayload(TypedDict): + """This class helps having pilot and user tokens without code duplication.""" + jti: str exp: datetime + + +class TokenPayload(BaseTokenPayload): dirac_policies: dict @@ -359,3 +390,56 @@ class PilotStatus(StrEnum): ABORTED = "Aborted" #: Cannot get information about the pilot status: UNKNOWN = "Unknown" + + +class PilotSecretConstraints(TypedDict, total=False): + VOs: list[str] # Authorize only a list of VOs + PilotStamps: list[str] # Authorize only a list of stamps + Sites: list[str] # Authorize only a list of sites + # ... + # We can add constraints here + + +class TokenType(StrEnum): + # Pilot token + PILOT_TOKEN = auto() + # User token + USER_TOKEN = auto() + + +class PilotSecretsInfo(BaseModel): + pilot_secret: str + pilot_secret_expires_in: int + + +class PilotAccessTokenPayload(BaseTokenPayload): + sub: str + vo: str + iss: str + pilot_stamp: str + + +class PilotInfo(BaseModel): + pilot_stamp: str + vo: str + sub: str + + +class PilotRefreshTokenPayload(BaseTokenPayload): + legacy_exchange: bool + + +class PilotCredentialsInfo(PilotSecretsInfo): + pilot_stamp: str + + +class PilotAuthCredentials(TypedDict): + pilot_stamp: str + pilot_secret: str + + +class VacuumPilotAuth(PilotAuthCredentials): + vo: str + grid_type: str + grid_site: str + status: str diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index ef11459f8..9cd946f49 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -163,6 +163,8 @@ class AuthSettings(ServiceSettingsBase): token_allowed_algorithms: list[str] = ["RS256", "EdDSA"] # noqa: S105 access_token_expire_minutes: int = 20 refresh_token_expire_minutes: int = 60 + pilot_secret_expire_seconds: int = 3600 + pilot_refresh_token_expire_hours: int = 168 available_properties: set[SecurityProperty] = Field( default_factory=SecurityProperty.available_properties diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index ff7309af5..5c92fe7f4 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from uuid import UUID + __all__ = [ "dotenv_files_from_environment", "serialize_credentials", @@ -19,7 +21,7 @@ from concurrent.futures import Future, ThreadPoolExecutor, wait from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, AsyncIterable, TypeVar +from typing import Any, AsyncIterable, Mapping, TypeVar, cast from cachetools import Cache, TTLCache @@ -271,3 +273,31 @@ async def batched_async( if strict and len(batch) != n: raise ValueError("batched(): incomplete batch") yield tuple(batch) + + +def extract_timestamp_from_uuid7(uuid_str: str) -> datetime: + u = UUID(uuid_str) + ts_bytes = u.bytes[0:6] # First 48 bits = timestamp in ms + timestamp_ms = int.from_bytes(ts_bytes, byteorder="big") + # Convert into seconds then to datetime + return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc) + + +T_DICTS = TypeVar("T_DICTS", bound=Mapping[str, Any]) + + +def recursive_dict_merge(x: T_DICTS, y: T_DICTS) -> T_DICTS: + result: dict[str, Any] = dict(x) + + for k, v in y.items(): + if k in result: + if isinstance(result[k], dict) and isinstance(v, dict): + result[k] = recursive_dict_merge(result[k], v) + elif isinstance(result[k], list) and isinstance(v, list): + result[k] = result[k] + v + else: + result[k] = v + else: + result[k] = v + + return cast(T_DICTS, result) diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index 0bfb32e07..cb6a3fbbe 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -3,20 +3,24 @@ from datetime import datetime, timezone from typing import Any -from sqlalchemy import bindparam +from sqlalchemy import DateTime, bindparam from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import delete, insert, update +from uuid_utils import uuid7 from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, PilotNotFoundError, + SecretNotFoundError, ) from diracx.core.models import ( PilotFieldsMapping, + PilotSecretConstraints, PilotStatus, SearchSpec, SortSpec, ) +from diracx.db.sql.utils.functions import utcnow from ..utils import ( BaseSQLDB, @@ -26,6 +30,7 @@ PilotAgents, PilotAgentsDBBase, PilotOutput, + PilotSecrets, ) @@ -118,6 +123,31 @@ async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): "Engine Specific error not caught" + str(e) ) from e + async def insert_unique_secrets( + self, + hashed_secrets: list[bytes], + secret_global_use_count_max: int | None = 1, + secret_constraints: dict[bytes, PilotSecretConstraints] = {}, + ): + """Bulk insert secrets. + + Raises: + - NotImplementedError if we have an IntegrityError not caught + + """ + values = [ + { + "SecretUUID": str(uuid7()), + "SecretRemainingUseCount": secret_global_use_count_max, + "HashedSecret": hashed_secret, + "SecretConstraints": secret_constraints.get(hashed_secret, {}), + } + for hashed_secret in hashed_secrets + ] + + stmt = insert(PilotSecrets).values(values) + await self.conn.execute(stmt) + # ----------------------------- Delete Functions ----------------------------- async def delete_pilots(self, pilot_ids: list[int]): @@ -140,6 +170,23 @@ async def delete_pilot_logs(self, pilot_ids: list[int]): await self.conn.execute(stmt) + async def delete_secrets(self, secret_uuids: list[str]): + """Bulk delete secrets. + + Raises SecretNotFoundError if one of the secret was not found. + """ + stmt = delete(PilotSecrets).where(PilotSecrets.secret_uuid.in_(secret_uuids)) + + res = await self.conn.execute(stmt) + + if res.rowcount != len(secret_uuids): + raise SecretNotFoundError( + "At least one of the secret has not been deleted." + ) + + # We NEED to commit here, because we will raise an error after this function + await self.conn.commit() + # ----------------------------- Update Functions ----------------------------- async def update_pilot_fields( @@ -194,6 +241,91 @@ async def update_pilot_fields( if res.rowcount != len(pilot_stamps_to_fields_mapping): raise PilotNotFoundError("at least one of the given pilot does not exist.") + async def update_pilot_secret_use_time(self, secret_uuid: str) -> None: + """Updates when a pilot uses a secret. + + Raises PilotNotFoundError if the pilot does not exist + + """ + # Prepare the update statement + stmt = ( + update(PilotSecrets) + .values( + pilot_secret_use_date=utcnow(), + secret_remaining_use_count=PilotSecrets.secret_remaining_use_count - 1, + ) + .where(PilotSecrets.secret_uuid == secret_uuid) + ) + + # Execute the update using the connection + res = await self.conn.execute(stmt) + + if res.rowcount == 0: + raise SecretNotFoundError("Unknown secret") + + async def update_pilot_secrets_constraints( + self, hashed_secrets_to_pilot_stamps_mapping: list[dict[str, Any]] + ): + """Bulk associate pilots with secrets by updating theirs constraints. + + Important: We have to provide the updated constraints. + + Raises: + - PilotNotFoundError if one of the pilot does not exist + - NotImplementedError if at least of the pilot + + """ + # Better to give as a parameter pilot to secret associations, rather than associating here. + + stmt = ( + update(PilotSecrets) + .where(PilotSecrets.hashed_secret == bindparam("PilotHashedSecret")) + .values({"SecretConstraints": bindparam("PilotSecretConstraints")}) + ) + + try: + await self.conn.execute(stmt, hashed_secrets_to_pilot_stamps_mapping) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise SecretNotFoundError( + detail="at least one of these secrets does not exist", + ) from e + raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e + + async def set_secret_expirations( + self, secret_uuids: list[str], pilot_secret_expiration_dates: list[DateTime] + ): + """Bulk set expiration dates to secrets. + + Raises: + - SecretNotFoundError if one of the secret_uuid is not associated with a secret. + - NotImplementedError if a integrity error is not caught. + - + + """ + values = [ + {"b_SecretUUID": secret_uuid, "SecretExpirationDate": pilot_secret} + for secret_uuid, pilot_secret in zip( + secret_uuids, pilot_secret_expiration_dates + ) + ] + + # Prepare the update statement + stmt = ( + update(PilotSecrets) + .where(PilotSecrets.secret_uuid == bindparam("b_SecretUUID")) + .values({"SecretExpirationDate": bindparam("SecretExpirationDate")}) + ) + + try: + await self.conn.execute(stmt, values) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise SecretNotFoundError( + detail="at least one of these secrets does not exist", + ) from e + raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e + # ----------------------------- Search Functions ----------------------------- async def search_pilots( @@ -238,6 +370,27 @@ async def search_pilot_to_job_mapping( page=page, ) + async def search_secrets( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for secrets in the database.""" + return await self._search( + table=PilotSecrets, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + async def pilot_summary( self, group_by: list[str], search: list[SearchSpec] ) -> list[dict[str, str | int]]: diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index af087f1f8..4ad4c9cb3 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -1,12 +1,17 @@ from __future__ import annotations from sqlalchemy import ( + BINARY, + JSON, DateTime, Double, Index, Integer, + SmallInteger, String, Text, + UniqueConstraint, + Uuid, ) from sqlalchemy.orm import declarative_base @@ -61,3 +66,28 @@ class PilotOutput(PilotAgentsDBBase): pilot_id = Column("PilotID", Integer, primary_key=True) std_output = Column("StdOutput", Text) std_error = Column("StdError", Text) + + +class PilotSecrets(PilotAgentsDBBase): + __tablename__ = "PilotSecrets" + + secret_uuid = Column("SecretUUID", Uuid(as_uuid=False), primary_key=True) + + hashed_secret = Column("HashedSecret", BINARY(32)) + # Global count + # Null: Infinite use + secret_remaining_use_count = NullColumn( + "SecretRemainingUseCount", SmallInteger, default=1 + ) + secret_expiration_date = NullColumn("SecretExpirationDate", DateTime(timezone=True)) + # To authorize only specific pilots to access a secret + # The constraint format follows diracx.code.models.PilotSecretConstraints + secret_constraints = NullColumn("SecretConstraints", JSON) + + # If a date is set, then it used a secret (acts also like a "PilotUsedSecret" field) + pilot_secret_use_date = NullColumn("PilotSecretUseDate", DateTime(timezone=True)) + + __table_args__ = ( + UniqueConstraint("HashedSecret", name="uq_hashed_secret"), + Index("HashedSecret", "HashedSecret"), + ) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 53b3f3c96..9834bd965 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -7,11 +7,7 @@ apply_search_filters, apply_sort_constraints, ) -from .functions import ( - hash, - substract_date, - utcnow, -) +from .functions import hash, raw_hash, substract_date, utcnow from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( @@ -25,6 +21,7 @@ "EnumColumn", "hash", "NullColumn", + "raw_hash", "substract_date", "SQLDBUnavailableError", "utcnow", diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index 34cb2a0da..c911324b5 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -140,3 +140,7 @@ def substract_date(**kwargs: float) -> datetime: def hash(code: str): return hashlib.sha256(code.encode()).hexdigest() + + +def raw_hash(code: str): + return hashlib.sha256(code.encode()).digest() diff --git a/diracx-db/tests/pilots/test_pilot_auth.py b/diracx-db/tests/pilots/test_pilot_auth.py new file mode 100644 index 000000000..99ea0c58e --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_auth.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from datetime import timedelta +from random import shuffle +from typing import AsyncGenerator, Generator + +import freezegun +import pytest +import sqlalchemy + +from diracx.core.exceptions import ( + BadPilotCredentialsError, + PilotNotFoundError, + SecretHasExpiredError, + SecretNotFoundError, +) +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.utils.functions import raw_hash +from diracx.testing.time import mock_sqlite_time + +from .utils import ( + add_secrets_and_time, # noqa: F401 + add_stamps, # noqa: F401 + verify_pilot_secret, +) + + +@pytest.fixture +async def pilot_db() -> AsyncGenerator[PilotAgentsDB, None]: + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + sqlalchemy.event.listen( + agents_db.engine.sync_engine, "connect", mock_sqlite_time + ) + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.fixture() +def frozen_time() -> Generator[freezegun.FreezeGun, None]: + with freezegun.freeze_time("2012-01-14") as ft: + yield ft + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +@pytest.mark.asyncio +async def test_create_pilot_and_verify_secret( + pilot_db: PilotAgentsDB, + add_secrets_and_time, # noqa: F811 + frozen_time: freezegun.FreezeGun, +): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + pairs = list(zip(stamps, secrets)) + # Shuffle it to prove that credentials are well associated + shuffle(pairs) + + async with pilot_db as pilot_db: + for stamp, secret in pairs: + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamp, + hashed_secret=raw_hash(secret), + frozen_time=frozen_time, + ) + + with pytest.raises(SecretNotFoundError): + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamps[0], + hashed_secret=raw_hash("I love stawberries :)"), + frozen_time=frozen_time, + ) + + with pytest.raises(PilotNotFoundError): + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp="I am a spider", + hashed_secret=raw_hash(secrets[0]), + frozen_time=frozen_time, + ) + + +@pytest.mark.parametrize("secret_duration_sec", [1]) +@pytest.mark.asyncio +async def test_create_pilot_and_verify_secret_with_delay( + pilot_db: PilotAgentsDB, + add_secrets_and_time, # noqa: F811 + frozen_time: freezegun.FreezeGun, +): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + # Move forward few minutes + frozen_time.tick(delta=timedelta(minutes=5)) + + async with pilot_db as pilot_db: + with pytest.raises(SecretHasExpiredError): + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamps[0], + hashed_secret=raw_hash(secrets[0]), + frozen_time=frozen_time, + ) + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +@pytest.mark.asyncio +async def test_create_pilot_and_verify_secret_too_much_secret_use( + pilot_db: PilotAgentsDB, + add_secrets_and_time, # noqa: F811 + frozen_time: freezegun.FreezeGun, +): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + # First login, should work + async with pilot_db as pilot_db: + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamps[0], + hashed_secret=raw_hash(secrets[0]), + frozen_time=frozen_time, + ) + + # Second login, should not work because maxed out at 1 try + # If the foreign key works, we should have "SecretNotFoundError" + with pytest.raises(SecretNotFoundError): + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamps[0], + hashed_secret=raw_hash(secrets[0]), + frozen_time=frozen_time, + ) + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +@pytest.mark.asyncio +async def test_create_pilot_and_login_with_bad_secret( + pilot_db: PilotAgentsDB, + add_secrets_and_time, # noqa: F811 + frozen_time: freezegun.FreezeGun, +): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + async with pilot_db as pilot_db: + # Pilot1 will try to login with every other pilots's secret + for secret in secrets[1:]: + with pytest.raises(BadPilotCredentialsError): + await verify_pilot_secret( + pilot_db=pilot_db, + pilot_stamp=stamps[0], + hashed_secret=raw_hash(secret), + frozen_time=frozen_time, + ) diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py index 793310d0d..df73fd3ec 100644 --- a/diracx-db/tests/pilots/utils.py +++ b/diracx-db/tests/pilots/utils.py @@ -1,12 +1,20 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Any +import freezegun import pytest from sqlalchemy import update +from diracx.core.exceptions import ( + BadPilotCredentialsError, + PilotNotFoundError, + SecretHasExpiredError, + SecretNotFoundError, +) from diracx.core.models import ( + PilotSecretConstraints, ScalarSearchOperator, ScalarSearchSpec, VectorSearchOperator, @@ -14,6 +22,7 @@ ) from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.db.sql.pilots.schema import PilotAgents +from diracx.db.sql.utils.functions import raw_hash MAIN_VO = "lhcb" N = 100 @@ -24,6 +33,8 @@ async def get_pilots_by_stamp( pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] ) -> list[dict[Any, Any]]: + if parameters: + parameters.append("PilotStamp") _, pilots = await pilot_db.search_pilots( parameters=parameters, search=[ @@ -61,6 +72,56 @@ async def get_pilot_jobs_ids_by_pilot_id( return [job["JobID"] for job in jobs] +async def get_secrets_by_hashed_secrets( + pilot_db: PilotAgentsDB, hashed_secrets: list[bytes], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, secrets = await pilot_db.search_secrets( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="HashedSecret", + operator=VectorSearchOperator.IN, + values=hashed_secrets, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return secrets + + +async def get_secrets_by_uuid( + pilot_db: PilotAgentsDB, secret_uuids: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + parameters.append("SecretUUID") # To avoid bug later on `found_keys = ...` + + _, secrets = await pilot_db.search_secrets( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="SecretUUID", + operator=VectorSearchOperator.IN, + values=secret_uuids, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + # Custom handling, to see which secret_uuid does not exist + # TODO: Add missing in the error + found_keys = {row["SecretUUID"] for row in secrets} + missing = set(secret_uuids) - found_keys + + if missing: + raise SecretNotFoundError(detail=str(missing)) + + return secrets + + # ------------ Creating data ------------ @@ -149,3 +210,113 @@ async def create_old_pilots_environment(pilot_db, create_timed_pilots): await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old + + +@pytest.fixture +async def add_secrets_and_time( + pilot_db, add_stamps, secret_duration_sec, frozen_time: freezegun.FreezeGun +): + # Retrieve the stamps from the add_stamps fixture + stamps = [pilot["PilotStamp"] for pilot in await add_stamps()] + + # Add a VO restriction as well as association with a specific pilot + secrets = [f"AW0nd3rfulS3cr3t_{str(i)}" for i in range(len(stamps))] + hashed_secrets = [raw_hash(secret) for secret in secrets] + constraints = { + hashed_secret: PilotSecretConstraints(VOs=[MAIN_VO], PilotStamps=[stamp]) + for hashed_secret, stamp in zip(hashed_secrets, stamps) + } + + async with pilot_db as pilot_db: + # Add creds + await pilot_db.insert_unique_secrets( + hashed_secrets=hashed_secrets, secret_constraints=constraints + ) + + # Associate with pilot + secrets_obj = await get_secrets_by_hashed_secrets(pilot_db, hashed_secrets) + + assert len(secrets_obj) == len(hashed_secrets) == len(stamps) + + # extract_timestamp_from_uuid7(secret_obj["SecretUUID"]) does not work here + # See #548 + expiration_date = [ + datetime.now(timezone.utc) + timedelta(seconds=secret_duration_sec) + for secret_obj in secrets_obj + ] + + await pilot_db.set_secret_expirations( + secret_uuids=[secret_obj["SecretUUID"] for secret_obj in secrets_obj], + pilot_secret_expiration_dates=expiration_date, + ) + + # Return both non-hashed secrets and stamps + return {"stamps": stamps, "secrets": secrets} + + +# ------------ Verifying data ------------ + + +async def verify_pilot_secret( + pilot_stamp: str, + pilot_db: PilotAgentsDB, + hashed_secret: bytes, + frozen_time: freezegun.FreezeGun, +) -> None: + # 1. Get the pilot + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + parameters=["VO", "PilotStamp"], + ) + if len(pilots) == 0: + raise PilotNotFoundError() + pilot = dict(pilots[0]) + + # 2. Get the secret itself + secrets = await get_secrets_by_hashed_secrets( + pilot_db=pilot_db, hashed_secrets=[hashed_secret] + ) + if len(secrets) == 0: + raise SecretNotFoundError(str(hashed_secret)) + secret = secrets[0] + secret_uuid = secret["SecretUUID"] + secret_constraints = PilotSecretConstraints(**secret["SecretConstraints"]) + + # 3. Check the constraints + await check_pilot_constraints(pilot=pilot, secret_constraints=secret_constraints) + + # 4. Check if the secret is expired + now = datetime.now(tz=timezone.utc) + # Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454 + expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc) + if expiration < now: + await pilot_db.delete_secrets([secret_uuid]) + + raise SecretHasExpiredError( + f"expiration_date {secret['SecretExpirationDate']}", + ) + + # 5. Now the pilot is authorized, change when the pilot used the secret. + await pilot_db.update_pilot_secret_use_time( + secret_uuid=secret_uuid, + ) + + # 6. Delete the secret if its count attained the secret_global_use_count_max + if secret["SecretRemainingUseCount"]: + # If we use it another time, SecretRemainingUseCount will be equal to 0 so we can delete it + if secret["SecretRemainingUseCount"] == 1: + await pilot_db.delete_secrets([secret_uuid]) + + +async def check_pilot_constraints( + pilot: dict[str, Any], secret_constraints: PilotSecretConstraints +): + key_map = {"VOs": "VO", "PilotStamps": "PilotStamp", "Sites": "Site"} + + for constraint_key, pilot_key in key_map.items(): + allowed_values = secret_constraints.get(constraint_key) + if allowed_values: + pilot_value = pilot.get(pilot_key) + if pilot_value is None or pilot_value not in allowed_values: + raise BadPilotCredentialsError() diff --git a/diracx-logic/src/diracx/logic/auth/token.py b/diracx-logic/src/diracx/logic/auth/token.py index 1d3f924e2..d6b85a6aa 100644 --- a/diracx-logic/src/diracx/logic/auth/token.py +++ b/diracx-logic/src/diracx/logic/auth/token.py @@ -20,9 +20,10 @@ ) from diracx.core.models import ( AccessTokenPayload, + BaseTokenPayload, GrantType, RefreshTokenPayload, - TokenPayload, + TokenType, ) from diracx.core.properties import SecurityProperty from diracx.core.settings import AuthSettings @@ -76,9 +77,7 @@ async def get_oidc_token( legacy_exchange, refresh_token_expire_minutes, include_refresh_token, - ) = await get_oidc_token_info_from_refresh_flow( - refresh_token, auth_db, settings - ) + ) = await get_token_info_from_refresh_flow(refresh_token, auth_db, settings) else: raise NotImplementedError(f"Grant type not implemented {grant_type}") @@ -160,13 +159,16 @@ async def get_oidc_token_info_from_authorization_flow( return (oidc_token_info, scope) -async def get_oidc_token_info_from_refresh_flow( - refresh_token: str, auth_db: AuthDB, settings: AuthSettings +async def get_token_info_from_refresh_flow( + refresh_token: str, + auth_db: AuthDB, + settings: AuthSettings, + token_type: TokenType = TokenType.USER_TOKEN, ) -> tuple[dict, str, bool, float, bool]: """Get OIDC token information from the refresh token DB and check few parameters before returning it.""" # Decode the refresh token to get the JWT ID jti, exp, legacy_exchange = await verify_dirac_refresh_token( - refresh_token, settings + refresh_token, settings, token_type ) # Get some useful user information from the refresh token entry in the DB @@ -356,7 +358,7 @@ async def exchange_token( return access_payload, refresh_payload -def create_token(payload: TokenPayload, settings: AuthSettings) -> str: +def create_token(payload: BaseTokenPayload, settings: AuthSettings) -> str: """Create a JWT token with the given payload and settings.""" signing_key = None for key in settings.token_keystore.jwks.keys: diff --git a/diracx-logic/src/diracx/logic/auth/utils.py b/diracx-logic/src/diracx/logic/auth/utils.py index 184fc9f6a..ced150448 100644 --- a/diracx-logic/src/diracx/logic/auth/utils.py +++ b/diracx-logic/src/diracx/logic/auth/utils.py @@ -15,8 +15,13 @@ from uuid_utils import UUID from diracx.core.config.schema import Config -from diracx.core.exceptions import AuthorizationError, IAMClientError, IAMServerError -from diracx.core.models import GrantType +from diracx.core.exceptions import ( + AuthorizationError, + BadTokenError, + IAMClientError, + IAMServerError, +) +from diracx.core.models import GrantType, TokenType from diracx.core.properties import SecurityProperty from diracx.core.settings import AuthSettings @@ -208,6 +213,7 @@ def read_token( async def verify_dirac_refresh_token( refresh_token: str, settings: AuthSettings, + token_type: TokenType = TokenType.USER_TOKEN, ) -> tuple[UUID, float, bool]: """Verify dirac user token and return a UserInfo class Used for each API endpoint. @@ -216,6 +222,12 @@ async def verify_dirac_refresh_token( refresh_token, settings.token_keystore.jwks, settings.token_allowed_algorithms ) + if token_type == TokenType.USER_TOKEN and "dirac_policies" not in claims: + raise BadTokenError("This is not a user token.") + + if token_type == TokenType.PILOT_TOKEN and "dirac_policies" in claims: + raise BadTokenError("This is not a pilot token.") + return ( UUID(claims["jti"]), float(claims["exp"]), diff --git a/diracx-logic/src/diracx/logic/pilots/auth.py b/diracx-logic/src/diracx/logic/pilots/auth.py new file mode 100644 index 000000000..ec9ae2188 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/auth.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from secrets import token_hex +from typing import Any, cast + +from uuid_utils import uuid7 + +from diracx.core.exceptions import ( + BadPilotCredentialsError, + SecretHasExpiredError, +) +from diracx.core.models import ( + PilotAccessTokenPayload, + PilotAuthCredentials, + PilotRefreshTokenPayload, + PilotSecretConstraints, + PilotSecretsInfo, + TokenResponse, + TokenType, + VacuumPilotAuth, +) +from diracx.core.settings import AuthSettings +from diracx.core.utils import extract_timestamp_from_uuid7, recursive_dict_merge +from diracx.db.sql import AuthDB, PilotAgentsDB +from diracx.db.sql.utils.functions import raw_hash +from diracx.logic.auth.token import ( + create_token, + get_token_info_from_refresh_flow, + insert_refresh_token, +) +from diracx.logic.pilots.query import ( + get_pilots_by_stamp, + get_secrets_by_hashed_secrets, +) + + +async def create_raw_secrets( + n: int, + pilot_db: PilotAgentsDB, + settings: AuthSettings, + secret_constraint: PilotSecretConstraints, + pilot_secret_use_count_max: int | None = 1, + expiration_minutes: int | None = None, +) -> tuple[list[str], list[int]]: + # Get a random string + # Can be customized + random_secrets = [generate_pilot_secret() for _ in range(n)] + + hashed_secrets = [raw_hash(random_secret) for random_secret in random_secrets] + + secret_constraints = { + hashed_secret: secret_constraint for hashed_secret in hashed_secrets + } + + # Insert secrets + await pilot_db.insert_unique_secrets( + hashed_secrets=hashed_secrets, + secret_global_use_count_max=pilot_secret_use_count_max, + secret_constraints=secret_constraints, + ) + + secrets_added = await get_secrets_by_hashed_secrets( + pilot_db=pilot_db, + hashed_secrets=hashed_secrets, + parameters=["SecretUUID"], # For efficiency + ) + + # If we have millions of pilots to add, can take few seconds / minutes to add + expiration_dates = [ + extract_timestamp_from_uuid7(secret["SecretUUID"]) + + timedelta( + seconds=( + expiration_minutes * 60 + if expiration_minutes + else settings.pilot_secret_expire_seconds + ) + ) + for secret in secrets_added + ] + secret_uuids = [secret["SecretUUID"] for secret in secrets_added] + + # Helps compatibility between sql engines + await pilot_db.set_secret_expirations( + secret_uuids=secret_uuids, + pilot_secret_expiration_dates=expiration_dates, # type: ignore + ) + + expiration_dates_timestamps = [ + int(expire_date.timestamp()) for expire_date in expiration_dates + ] + + return random_secrets, expiration_dates_timestamps + + +async def create_secrets( + n: int, + pilot_db: PilotAgentsDB, + settings: AuthSettings, + secret_constraint: PilotSecretConstraints, + pilot_secret_use_count_max: int | None = 1, + expiration_minutes: int | None = None, +) -> list[PilotSecretsInfo]: + pilot_secrets, expiration_dates_timestamps = await create_raw_secrets( + n=n, + pilot_db=pilot_db, + settings=settings, + pilot_secret_use_count_max=pilot_secret_use_count_max, + expiration_minutes=expiration_minutes, + secret_constraint=secret_constraint, + ) + + return [ + PilotSecretsInfo( + pilot_secret=secret, + pilot_secret_expires_in=expires_in, + ) + for secret, expires_in in zip(pilot_secrets, expiration_dates_timestamps) + ] + + +async def update_secrets_constraints( + pilot_db: PilotAgentsDB, + secrets_to_constraints_dict: dict[str, PilotSecretConstraints], +): + # 1. Create a mapping that uses hashed_secret + # Modify the mapping to use hashed_secrets instead of secrets + hashed_secrets_to_pilot_stamps_dict = { + raw_hash(secret): constraints + for secret, constraints in secrets_to_constraints_dict.items() + } + # Now the dictionary follows: {b"": []} + # + # If we had a list like so : [{"b_...UUID": , "b_...Stamp": }], to update the JSON we would need + # to groupby it if we use multiple times the same secret (to modify the JSON by merging and not overriding) + + # 2. Get the secret ids to later associate them with pilots + # It also verifies that all secrets exist + secrets_obj = await get_secrets_by_hashed_secrets( + pilot_db=pilot_db, + hashed_secrets=list(hashed_secrets_to_pilot_stamps_dict.keys()), + parameters=["SecretConstraints"], # For efficiency, we don't need more info + ) + + # Mapping [ {"PilotHashedSecret": b"", "PilotSecretConstraints": {...}} ] + # This is useful to update inside the database, but it is not useful to merge the old JSON with the new one + hashed_secrets_to_pilot_stamps_mapping: list[dict[str, Any]] = [] + + # 3. Merge the constraints so that we don't loose the old ones + for secret_obj in secrets_obj: + # Get the current constraints and hashed_secret + secret_constraints = PilotSecretConstraints(**secret_obj["SecretConstraints"]) + hashed_secret = secret_obj["HashedSecret"] + + # Merge it with the given constraints + new_secret_constraints = hashed_secrets_to_pilot_stamps_dict[hashed_secret] + + hashed_secrets_to_pilot_stamps_mapping.append( + { + "PilotHashedSecret": hashed_secret, + "PilotSecretConstraints": recursive_dict_merge( + secret_constraints, new_secret_constraints + ), + } + ) + + await pilot_db.update_pilot_secrets_constraints( + hashed_secrets_to_pilot_stamps_mapping + ) + + +async def verify_pilot_credentials( + pilot_db: PilotAgentsDB, + auth_db: AuthDB, + settings: AuthSettings, + credentials: PilotAuthCredentials | VacuumPilotAuth, +) -> TokenResponse: + hashed_secret = raw_hash(credentials["pilot_secret"]) + + if "vo" in credentials and "grid_type" in credentials: + credentials = cast(VacuumPilotAuth, credentials) + # 1-bis. DIRAC's `dirac-admin-add-pilot` mechanic: + # If a pilot does not exist yet (vacuum case), we add it in the db + # In that special case, we need a VO, grid_type, etc. + await pilot_db.add_pilots( + pilot_stamps=[credentials["pilot_stamp"]], + vo=credentials["vo"], + grid_type=credentials["grid_type"], + grid_site=credentials["grid_site"], + status=credentials["status"], + ) + + # 1. Get the pilot + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=[credentials["pilot_stamp"]], + parameters=["VO"], + allow_missing=False, + ) + pilot = dict(pilots[0]) + + # 2. Get the secret itself + secrets = await get_secrets_by_hashed_secrets( + pilot_db=pilot_db, hashed_secrets=[hashed_secret] + ) + secret = secrets[0] + secret_uuid = secret["SecretUUID"] + secret_constraints = PilotSecretConstraints(**secret["SecretConstraints"]) + + # 3. Check the constraints + await check_pilot_constraints(pilot=pilot, secret_constraints=secret_constraints) + + # 4. Check if the secret is expired + now = datetime.now(tz=timezone.utc) + # Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454 + expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc) + if expiration < now: + await pilot_db.delete_secrets([secret_uuid]) + + raise SecretHasExpiredError( + detail=f"expiration_date{secret['SecretExpirationDate']}", + ) + + # 5. Now the pilot is authorized, change when the pilot used the secret. + await pilot_db.update_pilot_secret_use_time( + secret_uuid=secret_uuid, + ) + + # 6. Delete the secret if its count attained the secret_global_use_count_max + if secret["SecretRemainingUseCount"]: + # If we use it another time, SecretRemainingUseCount will be equal to 0 so we can delete it + if secret["SecretRemainingUseCount"] == 1: + await pilot_db.delete_secrets([secret_uuid]) + + # Get token, and serialize + access_token_payload, refresh_token_payload = await generate_pilot_tokens( + vo=pilot["VO"], + pilot_stamp=credentials["pilot_stamp"], + auth_db=auth_db, + settings=settings, + refresh_token=None, + ) + + return await serialize_tokens( + access_token_payload=access_token_payload, + refresh_token_payload=refresh_token_payload, + settings=settings, + ) + + +async def refresh_pilot_token( + pilot_stamp: str, + auth_db: AuthDB, + settings: AuthSettings, + pilot_db: PilotAgentsDB, + refresh_token: str | None = None, +) -> TokenResponse: + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + parameters=["VO"], + allow_missing=False, + ) + pilot = pilots[0] # Semantic + vo = pilot["VO"] + + access_token_payload, refresh_token_payload = await generate_pilot_tokens( + vo=vo, + pilot_stamp=pilot_stamp, + auth_db=auth_db, + settings=settings, + refresh_token=refresh_token, + ) + + return await serialize_tokens( + access_token_payload=access_token_payload, + refresh_token_payload=refresh_token_payload, + settings=settings, + ) + + +async def check_pilot_constraints( + pilot: dict[str, Any], secret_constraints: PilotSecretConstraints +): + key_map = {"VOs": "VO", "PilotStamps": "PilotStamp", "Sites": "Site"} + + err = BadPilotCredentialsError() + + for constraint_key, pilot_key in key_map.items(): + expected = secret_constraints.get(constraint_key) + if expected is not None: + pilot_value = pilot.get(pilot_key) + if pilot_value is None: + raise err + + if isinstance(expected, list): + if pilot_value not in expected: + raise err + else: + if pilot_value != expected: + raise err + + +def generate_pilot_secret() -> str: + # Can change with time + return token_hex(32) + + +async def exchange_token( + scope: str, + sub: str, + vo: str, + pilot_stamp: str, + auth_db: AuthDB, + settings: AuthSettings, + legacy_exchange: bool, + include_refresh_token: bool, +) -> tuple[PilotAccessTokenPayload, PilotRefreshTokenPayload | None]: + """Method called to exchange the OIDC token for a DIRAC generated access token.""" + # Merge the VO with the stamp to get a sub + + creation_time = datetime.now(timezone.utc) + # Insert the refresh token with user details into the RefreshTokens table + # User details are needed to regenerate access tokens later + jti, creation_time = await insert_refresh_token( + auth_db=auth_db, + subject=sub, + scope=scope, + ) + + refresh_payload: PilotRefreshTokenPayload | None = None + if include_refresh_token: + refresh_payload = { + "jti": str(jti), + "exp": creation_time + + timedelta(hours=settings.pilot_refresh_token_expire_hours), + "legacy_exchange": legacy_exchange, + } + + # Generate access token payload + access_payload: PilotAccessTokenPayload = { + "sub": sub, + "vo": vo, + "iss": settings.token_issuer, + "jti": str(uuid7()), + # This field is redundant, but if later we change the sub, we won't need to change how we use the token + "pilot_stamp": pilot_stamp, + "exp": creation_time + timedelta(minutes=settings.access_token_expire_minutes), + } + + return access_payload, refresh_payload + + +async def generate_pilot_tokens( + vo: str, + pilot_stamp: str, + auth_db: AuthDB, + settings: AuthSettings, + refresh_token: str | None = None, +) -> tuple[PilotAccessTokenPayload, PilotRefreshTokenPayload | None]: + include_refresh_token = True + + if refresh_token is not None: + ( + pilot_info, + scope, + legacy_exchange, + _, + include_refresh_token, + ) = await get_token_info_from_refresh_flow( + refresh_token=refresh_token, + auth_db=auth_db, + settings=settings, + token_type=TokenType.PILOT_TOKEN, + ) + + sub = f"{vo}:{pilot_info['sub']}" + else: + # We don't need a user sub as before + # https://github.com/DIRACGrid/diracx/pull/421#issuecomment-2909087954 + # Same for the property, but it is useful as we store the scope (to detect pilots) + scope = f"vo:{vo} property:GenericPilot" + sub = f"{vo}:{pilot_stamp}" + legacy_exchange = False + + return await exchange_token( + scope=scope, + sub=sub, + vo=vo, + pilot_stamp=pilot_stamp, + auth_db=auth_db, + settings=settings, + legacy_exchange=legacy_exchange, + include_refresh_token=include_refresh_token, + ) + + +async def serialize_tokens( + access_token_payload: PilotAccessTokenPayload, + refresh_token_payload: PilotRefreshTokenPayload | None, + settings: AuthSettings, +): + access_token = create_token(payload=access_token_payload, settings=settings) + + refresh_token: str | None = None + if refresh_token_payload: + refresh_token = create_token(payload=refresh_token_payload, settings=settings) + + return TokenResponse( + access_token=access_token, + expires_in=settings.pilot_secret_expire_seconds, + refresh_token=refresh_token, + ) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index a6256a742..a74012b74 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -3,9 +3,15 @@ from datetime import datetime, timedelta, timezone from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError -from diracx.core.models import PilotFieldsMapping +from diracx.core.models import ( + PilotCredentialsInfo, + PilotFieldsMapping, + PilotSecretConstraints, +) +from diracx.core.settings import AuthSettings from diracx.db.sql import PilotAgentsDB +from .auth import create_raw_secrets, update_secrets_constraints from .query import ( get_outdated_pilots, get_pilot_ids_by_stamps, @@ -23,7 +29,10 @@ async def register_new_pilots( destination_site: str, status: str, pilot_job_references: dict[str, str] | None, -): + settings: AuthSettings, + generate_secrets: bool = True, + pilot_secret_use_count_max: int | None = None, +) -> list[PilotCredentialsInfo] | None: # [IMPORTANT] Check unicity of pilot stamps # If a pilot already exists, we raise an error (transaction will rollback) existing_pilots = await get_pilots_by_stamp( @@ -48,6 +57,37 @@ async def register_new_pilots( status=status, ) + if not generate_secrets: + return None + + pilot_secrets, expiration_dates_timestamps = await create_raw_secrets( + n=len(pilot_stamps), + pilot_db=pilot_db, + settings=settings, + pilot_secret_use_count_max=pilot_secret_use_count_max, + secret_constraint=PilotSecretConstraints(VOs=[vo]), + ) + + constraints = { + pilot_secret: PilotSecretConstraints(PilotStamps=[pilot_stamp], VOs=[vo]) + for pilot_secret, pilot_stamp in zip(pilot_secrets, pilot_stamps) + } + + await update_secrets_constraints( + pilot_db=pilot_db, secrets_to_constraints_dict=constraints + ) + + return [ + PilotCredentialsInfo( + pilot_stamp=pilot_stamp, + pilot_secret=secret, + pilot_secret_expires_in=expires_in, + ) + for pilot_stamp, secret, expires_in in zip( + pilot_stamps, pilot_secrets, expiration_dates_timestamps + ) + ] + async def delete_pilots( pilot_db: PilotAgentsDB, diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index b6cf504d7..7be9bd802 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Any -from diracx.core.exceptions import PilotNotFoundError +from diracx.core.exceptions import PilotNotFoundError, SecretNotFoundError from diracx.core.models import ( PilotStatus, ScalarSearchOperator, @@ -87,7 +87,7 @@ async def get_pilots_by_stamp( if missing: raise PilotNotFoundError( - detail=str(missing), + detail=f"Pilot(s) not found: {str(missing)}", ) return pilots @@ -189,3 +189,64 @@ async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): } ) return await pilot_db.pilot_summary(body.grouping, body.search) + + +async def get_secrets_by_hashed_secrets( + pilot_db: PilotAgentsDB, hashed_secrets: list[bytes], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + if parameters: + parameters.append("HashedSecret") + + _, secrets = await pilot_db.search_secrets( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="HashedSecret", + operator=VectorSearchOperator.IN, + values=hashed_secrets, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # Custom handling, to see which hashed_secrets does not exist + found_keys = {row["HashedSecret"] for row in secrets} + missing = set(hashed_secrets) - found_keys + + if missing: + raise SecretNotFoundError(str(missing)) + + return secrets + + +async def get_secrets_by_uuid( + pilot_db: PilotAgentsDB, secret_uuids: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + if parameters: + parameters.append("SecretUUID") # To avoid bug later on `found_keys = ...` + + _, secrets = await pilot_db.search_secrets( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="SecretUUID", + operator=VectorSearchOperator.IN, + values=secret_uuids, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # Custom handling, to see which secret_uuid does not exist + # TODO: Add missing in the error + found_keys = {row["SecretUUID"] for row in secrets} + missing = set(secret_uuids) - found_keys + + if missing: + raise SecretNotFoundError(detail=str(missing)) + + return secrets diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 2038223ce..744788723 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -47,6 +47,7 @@ config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" pilots = "diracx.routers.pilots:router" +"pilots/internal" = "diracx.routers.pilot_resources:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" diff --git a/diracx-routers/src/diracx/routers/auth/__init__.py b/diracx-routers/src/diracx/routers/auth/__init__.py index ed71900f2..f11ccd3ac 100644 --- a/diracx-routers/src/diracx/routers/auth/__init__.py +++ b/diracx-routers/src/diracx/routers/auth/__init__.py @@ -5,6 +5,7 @@ from .authorize_code_flow import router as authorize_code_flow_router from .device_flow import router as device_flow_router from .management import router as management_router +from .pilots import router as pilot_router from .token import router as token_router from .utils import has_properties @@ -13,5 +14,6 @@ router.include_router(management_router) router.include_router(authorize_code_flow_router) router.include_router(token_router) +router.include_router(pilot_router) __all__ = ["has_properties", "verify_dirac_access_token"] diff --git a/diracx-routers/src/diracx/routers/auth/pilots.py b/diracx-routers/src/diracx/routers/auth/pilots.py new file mode 100644 index 000000000..e5ed668f1 --- /dev/null +++ b/diracx-routers/src/diracx/routers/auth/pilots.py @@ -0,0 +1,107 @@ +"""Token endpoint.""" + +from __future__ import annotations + +from typing import Annotated + +from fastapi import Body, HTTPException, status + +from diracx.core.exceptions import ( + BadPilotCredentialsError, + BadTokenError, + InvalidCredentialsError, + PilotNotFoundError, + SecretHasExpiredError, + SecretNotFoundError, +) +from diracx.core.models import ( + PilotAuthCredentials, + TokenResponse, + VacuumPilotAuth, +) +from diracx.logic.pilots.auth import refresh_pilot_token, verify_pilot_credentials + +from ..dependencies import ( + AuthDB, + AuthSettings, + PilotAgentsDB, +) +from ..fastapi_classes import DiracxRouter + +router = DiracxRouter(require_auth=False) + + +@router.post("/secret-exchange") +async def perform_secret_exchange( + pilot_db: PilotAgentsDB, + auth_db: AuthDB, + pilot_credentials: Annotated[ + PilotAuthCredentials | VacuumPilotAuth, + Body(description="Pilot credentials (stamp and secret)"), + ], + settings: AuthSettings, +) -> TokenResponse: + """This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's `dirac-admin-add-pilot`. + """ + try: + return await verify_pilot_credentials( + pilot_db=pilot_db, + auth_db=auth_db, + credentials=pilot_credentials, + settings=settings, + ) + except BadPilotCredentialsError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="bad credentials" + ) from e + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="bad pilot_stamp", + ) from e + except SecretNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="bad pilot_secret", + ) from e + except SecretHasExpiredError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="secret expired", + ) from e + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) from e + + +@router.post("/pilot-token") +async def refresh_pilot_tokens( + auth_db: AuthDB, + settings: AuthSettings, + pilot_db: PilotAgentsDB, + refresh_token: Annotated[ + str, Body(description="Refresh Token given at login by DiracX.") + ], + pilot_stamp: Annotated[str, Body(description="Pilot stamp")], +) -> TokenResponse: + """Endpoint where *only* pilots can exchange a refresh token for a token.""" + # Refresh it + try: + return await refresh_pilot_token( + pilot_stamp=pilot_stamp, + auth_db=auth_db, + settings=settings, + pilot_db=pilot_db, + refresh_token=refresh_token, + ) + except (InvalidCredentialsError, PilotNotFoundError, BadTokenError) as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e) + ) from e + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) from e diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index 0a71308ea..4f809e069 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -10,6 +10,7 @@ from joserfc.errors import JoseError from diracx.core.exceptions import ( + BadTokenError, DiracHttpResponseError, InvalidCredentialsError, PendingAuthorizationError, @@ -153,7 +154,7 @@ async def get_oidc_token( detail=str(e), headers={"WWW-Authenticate": "Bearer"}, ) from e - except PermissionError as e: + except (BadTokenError, PermissionError) as e: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=str(e), diff --git a/diracx-routers/src/diracx/routers/pilot_resources/__init__.py b/diracx-routers/src/diracx/routers/pilot_resources/__init__.py new file mode 100644 index 000000000..f6c2f551e --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilot_resources/__init__.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import logging + +from fastapi import Depends + +from diracx.routers.utils.pilots import ( + verify_dirac_pilot_access_token, +) + +from ..fastapi_classes import DiracxRouter +from .util import router as util_router + +logger = logging.getLogger(__name__) + + +# Require_auth set to False because it adds *user* auth, and not pilot's +# So we add it manually +router = DiracxRouter( + require_auth=False, dependencies=[Depends(verify_dirac_pilot_access_token)] +) + +router.include_router(util_router) diff --git a/diracx-routers/src/diracx/routers/pilot_resources/util.py b/diracx-routers/src/diracx/routers/pilot_resources/util.py new file mode 100644 index 000000000..c0c965e22 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilot_resources/util.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging +from typing import Annotated + +from fastapi import Depends + +from diracx.core.models import ( + PilotInfo, +) +from diracx.routers.utils.pilots import ( + AuthorizedPilotInfo, + verify_dirac_pilot_access_token, +) + +from ..fastapi_classes import DiracxRouter + +router = DiracxRouter(require_auth=False) + +logger = logging.getLogger(__name__) + + +@router.get("/pilotinfo") +async def userinfo( + pilot_info: Annotated[ + AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token) + ], +) -> PilotInfo: + """Get information about the user's identity.""" + return PilotInfo( + sub=pilot_info.sub, vo=pilot_info.vo, pilot_stamp=pilot_info.pilot_stamp + ) diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py index 03f9b8422..1c2f9d908 100644 --- a/diracx-routers/src/diracx/routers/pilots/__init__.py +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -8,6 +8,6 @@ logger = logging.getLogger(__name__) -router = DiracxRouter() +router = DiracxRouter(require_auth=True) router.include_router(management_router) router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index a383643d1..2304f166b 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -7,12 +7,21 @@ from diracx.core.exceptions import ( PilotAlreadyExistsError, + PilotNotFoundError, + SecretNotFoundError, ) from diracx.core.models import ( + PilotCredentialsInfo, PilotFieldsMapping, + PilotSecretConstraints, + PilotSecretsInfo, PilotStatus, ) from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR +from diracx.logic.pilots.auth import create_secrets +from diracx.logic.pilots.auth import ( + update_secrets_constraints as update_secrets_constraints_bl, +) from diracx.logic.pilots.management import ( delete_pilots as delete_pilots_bl, ) @@ -24,7 +33,7 @@ from diracx.logic.pilots.query import get_pilot_ids_by_job_id from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token -from ..dependencies import JobDB, PilotAgentsDB +from ..dependencies import AuthSettings, JobDB, PilotAgentsDB from ..fastapi_classes import DiracxRouter from .access_policies import ( ActionType, @@ -42,6 +51,7 @@ async def add_pilot_stamps( Body(description="List of the pilot stamps we want to add to the db."), ], vo: Annotated[str, Body(description="Pilot virtual organization.")], + settings: AuthSettings, check_permissions: CheckPilotManagementPolicyCallable, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", @@ -56,7 +66,15 @@ async def add_pilot_stamps( pilot_status: Annotated[ PilotStatus, Body(description="Status of the pilots.") ] = PilotStatus.SUBMITTED, -): + generate_secrets: Annotated[ + bool, + Body(description="If we want to create secrets with the pilots."), + ] = True, + pilot_secret_use_count_max: Annotated[ + int | None, + Body(description="How much time can a secret be used."), + ] = 1, +) -> list[PilotCredentialsInfo] | None: """Endpoint where a you can create pilots with their references. If a pilot stamp already exists, it will block the insertion. @@ -84,15 +102,18 @@ async def add_pilot_stamps( ) try: - await register_new_pilots( + return await register_new_pilots( pilot_db=pilot_db, pilot_stamps=pilot_stamps, - vo=vo, + vo=user_info.vo, grid_type=grid_type, grid_site=grid_site, destination_site=destination_site, pilot_job_references=pilot_references, status=pilot_status, + settings=settings, + generate_secrets=generate_secrets, + pilot_secret_use_count_max=pilot_secret_use_count_max, ) except PilotAlreadyExistsError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e @@ -164,6 +185,80 @@ async def delete_pilots( ) +@router.post("/secrets") +async def create_pilot_secrets( + n: Annotated[int, Body(description="Number of secrets to create.")], + expiration_minutes: Annotated[ + int | None, Body(description="Time in minutes before expiring.") + ], + pilot_secret_use_count_max: Annotated[ + int | None, Body(description="Number of times that we can use a secret.") + ], + vo: Annotated[str, Body(description="Only VO that can access a secret.")], + check_permissions: CheckPilotManagementPolicyCallable, + pilot_db: PilotAgentsDB, + settings: AuthSettings, +) -> list[PilotSecretsInfo]: + """Endpoint to create secrets.""" + await check_permissions(action=ActionType.MANAGE_PILOTS) + + if expiration_minutes and expiration_minutes <= 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="expiration_minutes must be strictly positive.", + ) + if pilot_secret_use_count_max and pilot_secret_use_count_max <= 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="pilot_secret_use_count_max is either None or a positive number.", + ) + + return await create_secrets( + n=n, + pilot_db=pilot_db, + settings=settings, + secret_constraint=PilotSecretConstraints(VOs=[vo]), + pilot_secret_use_count_max=pilot_secret_use_count_max, + expiration_minutes=expiration_minutes, + ) + + +@router.patch("/secrets", status_code=HTTPStatus.NO_CONTENT) +async def update_secrets_constraints( + secrets_to_constraints_dict: Annotated[ + dict[str, PilotSecretConstraints], + Body(description="Mapping between secrets and pilots.", embed=False), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Endpoint to associate pilots with secrets.""" + pilot_stamps = set() + for constraints in secrets_to_constraints_dict.values(): + if "PilotStamps" in constraints: + pilot_stamps.update(constraints["PilotStamps"]) + + await check_permissions( + action=ActionType.MANAGE_PILOTS, + ) + + try: + await update_secrets_constraints_bl( + pilot_db=pilot_db, + secrets_to_constraints_dict=secrets_to_constraints_dict, + ) + except SecretNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="one of the secrets does not exist", + ) from e + except PilotNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="one of the pilots does not exist", + ) from e + + EXAMPLE_UPDATE_FIELDS = { "Update the BenchMark field": { "summary": "Update BenchMark", diff --git a/diracx-routers/src/diracx/routers/utils/pilots.py b/diracx-routers/src/diracx/routers/utils/pilots.py new file mode 100644 index 000000000..b341908ea --- /dev/null +++ b/diracx-routers/src/diracx/routers/utils/pilots.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import re +from typing import Annotated + +from fastapi import Header, HTTPException, status +from joserfc.errors import JoseError +from joserfc.jwt import JWTClaimsRegistry +from pydantic import BaseModel + +from diracx.core.models import UUID, PilotInfo +from diracx.logic.auth.utils import read_token +from diracx.routers.dependencies import AuthSettings + + +class AuthInfo(BaseModel): + # raw token for propagation + bearer_token: str + + # token ID in the DB for Component + # unique jwt identifier for user + token_id: UUID + + +class AuthorizedPilotInfo(AuthInfo, PilotInfo): + pass + + +async def verify_dirac_pilot_access_token( + settings: AuthSettings, + authorization: Annotated[str | None, Header()] = None, +) -> AuthorizedPilotInfo: + """Verify dirac pilot token and return a AuthorizedPilotInfo class + Used for each API endpoint. + """ + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header is missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + if match := re.fullmatch(r"Bearer (.+)", authorization): + raw_token = match.group(1) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid authorization header", + ) + + try: + claims = read_token( + raw_token, + settings.token_keystore.jwks, + settings.token_allowed_algorithms, + claims_requests=JWTClaimsRegistry( + iss={"essential": True, "value": settings.token_issuer}, + ), + ) + + return AuthorizedPilotInfo( + bearer_token=raw_token, + token_id=claims["jti"], + sub=claims["sub"], + pilot_stamp=claims["pilot_stamp"], + vo=claims["vo"], + ) + # We catch KeyError if a user tries with its token to access this resource: + # -> claims["pilot_stamp"] will lead to a KeyError + except (JoseError, KeyError) as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid JWT", + ) from e diff --git a/diracx-routers/src/diracx/routers/utils/users.py b/diracx-routers/src/diracx/routers/utils/users.py index 2446556c6..daf7ff249 100644 --- a/diracx-routers/src/diracx/routers/utils/users.py +++ b/diracx-routers/src/diracx/routers/utils/users.py @@ -1,18 +1,15 @@ from __future__ import annotations import re -import uuid as std_uuid from typing import Annotated, Any from fastapi import Depends, HTTPException, status from fastapi.security import OpenIdConnect from joserfc.errors import JoseError from joserfc.jwt import JWTClaimsRegistry -from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler -from pydantic_core import CoreSchema, core_schema -from uuid_utils import UUID as _UUID +from pydantic import BaseModel -from diracx.core.models import UserInfo +from diracx.core.models import UUID, UserInfo from diracx.core.properties import SecurityProperty from diracx.logic.auth.utils import read_token from diracx.routers.dependencies import AuthSettings @@ -27,29 +24,6 @@ ) -class UUID(_UUID): - """Subclass of uuid_utils.UUID to add pydantic support.""" - - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - """Use the stdlib uuid.UUID schema for validation and serialization.""" - std_schema = handler(std_uuid.UUID) - - def to_uuid_utils(u: std_uuid.UUID) -> UUID: - return cls(str(u)) - - return core_schema.no_info_after_validator_function(to_uuid_utils, std_schema) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler - ) -> dict[str, Any]: - """Return the stdlib uuid.UUID schema for JSON serialization.""" - return handler(core_schema) - - class AuthInfo(BaseModel): # raw token for propagation bearer_token: str @@ -98,19 +72,21 @@ async def verify_dirac_access_token( iss={"essential": True, "value": settings.token_issuer}, ), ) - except JoseError as e: + + return AuthorizedUserInfo( + bearer_token=raw_token, + token_id=claims["jti"], + properties=claims["dirac_properties"], + sub=claims["sub"], + preferred_username=claims["preferred_username"], + dirac_group=claims["dirac_group"], + vo=claims["vo"], + policies=claims.get("dirac_policies", {}), + ) + except (JoseError, KeyError) as e: + # We catch KeyError to prevent pilots accessing user resources + # -> If a pilot tries, because he has not dirac_properties, KeyError will be raised raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT", ) from e - - return AuthorizedUserInfo( - bearer_token=raw_token, - token_id=claims["jti"], - properties=claims["dirac_properties"], - sub=claims["sub"], - preferred_username=claims["preferred_username"], - dirac_group=claims["dirac_group"], - vo=claims["vo"], - policies=claims.get("dirac_policies", {}), - ) diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py index ef03764f8..4889bf86d 100644 --- a/diracx-routers/tests/auth/test_standard.py +++ b/diracx-routers/tests/auth/test_standard.py @@ -814,6 +814,7 @@ async def test_keystore(test_client): "jti": "49ecc171-20be-5b88-0d65-26012c07f397", "exp": (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp(), "legacy_exchange": False, + "dirac_policies": {}, } # Generate keys diff --git a/diracx-routers/tests/pilots/test_pilot_auth.py b/diracx-routers/tests/pilots/test_pilot_auth.py new file mode 100644 index 000000000..618788b18 --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_auth.py @@ -0,0 +1,456 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from time import sleep + +import pytest +from fastapi.testclient import TestClient +from pytest_httpx import HTTPXMock + +from diracx.core.models import PilotSecretConstraints +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.utils.functions import raw_hash +from diracx.logic.pilots.query import ( + get_pilots_by_stamp, + get_secrets_by_hashed_secrets, +) + +from ..auth.test_standard import _get_tokens, auth_httpx_mock # noqa: F401 + +pytestmark = pytest.mark.enabled_dependencies( + [ + "PilotCredentialsAccessPolicy", + "PilotManagementAccessPolicy", + "DevelopmentSettings", + "AuthDB", + "AuthSettings", + "ConfigSource", + "BaseAccessPolicy", + "PilotAgentsDB", + ] +) + +MAIN_VO = "lhcb" +DIRAC_CLIENT_ID = "myDIRACClientID" +N = 100 + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +@pytest.fixture +def diracx_pilot_client(client_factory): + with client_factory.diracx_pilot() as client: + yield client + + +@pytest.fixture +def non_mocked_hosts(normal_test_client) -> list[str]: + return [normal_test_client.base_url.host] + + +@pytest.fixture +async def add_stamps(normal_test_client): + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(N)] + stamps = [f"stamp_{i}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + pilots = await get_pilots_by_stamp(db, stamps) + + return pilots + + +@pytest.fixture +async def add_secrets_and_time(normal_test_client, add_stamps, secret_duration_sec): + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db as pilot_db: + # Retrieve the stamps from the add_stamps fixture + stamps = [pilot["PilotStamp"] for pilot in add_stamps] + + # Add a VO restriction as well as association with a specific pilot + secrets = [f"AW0nd3rfulS3cr3t_{str(i)}" for i in range(len(stamps))] + hashed_secrets = [raw_hash(secret) for secret in secrets] + constraints = { + hashed_secret: PilotSecretConstraints(VOs=[MAIN_VO], PilotStamps=[stamp]) + for hashed_secret, stamp in zip(hashed_secrets, stamps) + } + + # Add creds + await pilot_db.insert_unique_secrets( + hashed_secrets=hashed_secrets, secret_constraints=constraints + ) + + # Associate with pilot + secrets_obj = await get_secrets_by_hashed_secrets(db, hashed_secrets) + + assert len(secrets_obj) == len(hashed_secrets) == len(stamps) + + # extract_timestamp_from_uuid7(secret_obj["SecretUUID"]) does not work here + # See #548 + expiration_date = [ + datetime.now(timezone.utc) + timedelta(seconds=secret_duration_sec) + for secret_obj in secrets_obj + ] + + await pilot_db.set_secret_expirations( + secret_uuids=[secret_obj["SecretUUID"] for secret_obj in secrets_obj], + pilot_secret_expiration_dates=expiration_date, + ) + + # Return both non-hashed secrets and stamps + return {"stamps": stamps, "secrets": secrets} + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +async def test_verify_secret(normal_test_client, add_secrets_and_time): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + pilot_stamp = stamps[0] + secret = secrets[0] + + # ----------------- Wrong password ----------------- + body = { + "pilot_stamp": pilot_stamp, + "pilot_secret": "My 1ncr3d1bl3 t0k3n", + } + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401, r.json() + assert r.json()["detail"] == "bad pilot_secret" + + # ----------------- Good password ----------------- + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 200, r.json() + + access_token = r.json()["access_token"] + refresh_token = r.json()["refresh_token"] + + assert access_token is not None + assert refresh_token is not None + + # ----------------- Wrong ID ----------------- + body = {"pilot_stamp": "It is a stamp", "pilot_secret": secret} + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + ) + + assert r.status_code == 401 + assert r.json()["detail"] == "bad pilot_stamp" + + # ----------------- Exchange for new tokens ----------------- + body = {"refresh_token": refresh_token, "pilot_stamp": pilot_stamp} + r = normal_test_client.post( + "/api/auth/pilot-token", + json=body, + ) + + assert r.status_code == 200, r.json() + + new_access_token = r.json()["access_token"] + new_refresh_token = r.json()["refresh_token"] + + # ----------------- Exchange token with old token ----------------- + body = {"refresh_token": refresh_token, "pilot_stamp": pilot_stamp} + r = normal_test_client.post( + "/api/auth/pilot-token", + json=body, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert r.status_code == 401, r.json() + + # ----------------- Exchange token with new token ----------------- + body = {"refresh_token": new_refresh_token, "pilot_stamp": pilot_stamp} + r = normal_test_client.post( + "/api/auth/pilot-token", + json=body, + headers={"Authorization": f"Bearer {new_access_token}"}, + ) + + # RFC6749 + # https://datatracker.ietf.org/doc/html/rfc6749#section-10.4 + assert r.status_code == 401, r.json() + + # ----------------- Overused Secret ----------------- + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401 + assert r.json()["detail"] == "bad pilot_secret" + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +async def test_vacuum_case(normal_test_client, add_secrets_and_time): + result = add_secrets_and_time + secrets = result["secrets"] + + pilot_stamp = "this_might_be_totally_unknown" + secret = secrets[0] + + # ----------------- Good password but unknown stamp ----------------- + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401 + assert r.json()["detail"] == "bad pilot_stamp" + + # ----------------- Good password and vacuum case but wrong stamp for the secret ----------------- + # add_secrets_and_time associates secret_n with stamp_n. + # Because our pilot_stamp does not have a secret associated to it + # (or at least one where it can't access), we have to create an "opened" secret (for every stamp) + # This will be done in the next section + + body = { + "pilot_stamp": pilot_stamp, + "pilot_secret": secret, + "vo": MAIN_VO, + "grid_type": "test", + "grid_site": "test", + "status": "Waiting", + } + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401, r.json() + assert r.json()["detail"] == "bad credentials" + + # ----------------- Add secret without restricting it to a certain stamp ----------------- + body = { + "n": 1, + "vo": MAIN_VO, + "expiration_minutes": 1, + "pilot_secret_use_count_max": 1, + } + + r = normal_test_client.post( + "/api/pilots/secrets", + json=body, + headers={"Content-Type": "application/json"}, + ) + assert r.status_code == 200, r.json() + + # Format : {"pilot_secret": "...", "pilot_secret_expires_in": ..., "pilot_stamps": None} + secrets_mapping = r.json() + secrets = [el["pilot_secret"] for el in secrets_mapping] + + assert len(secrets) == 1 + + secret = secrets[0] + + # ----------------- Good password and vacuum case ----------------- + + body = { + "pilot_stamp": pilot_stamp, + "pilot_secret": secret, + "vo": MAIN_VO, + "grid_type": "test", + "grid_site": "test", + "status": "Waiting", + } + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 200, r.json() + + access_token = r.json()["access_token"] + refresh_token = r.json()["refresh_token"] + + assert access_token is not None + assert refresh_token is not None + + # Get a pilot token, and try to access a pilot endpoint + r = normal_test_client.get( + "/api/pilots/internal/pilotinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert r.status_code == 200 + + +@pytest.mark.parametrize("secret_duration_sec", [2]) +async def test_expired_secret(normal_test_client, add_secrets_and_time): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + pilot_stamp = stamps[0] + secret = secrets[0] + + # ----------------- Secret that expired ----------------- + sleep(2) + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401 + assert r.json()["detail"] == "secret expired" + + # ----------------- Secret that expired, but reused ----------------- + # Should be deleted by the verify_pilot_secret + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 401 + assert r.json()["detail"] == "bad pilot_secret" + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +async def test_access_user_info_with_pilot_token( + normal_test_client, add_secrets_and_time +): + # ----------------- Access user info but with a pilot token ----------------- + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + pilot_stamp = stamps[0] + secret = secrets[0] + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 200, r.json() + + access_token = r.json()["access_token"] + refresh_token = r.json()["refresh_token"] + + assert access_token is not None + assert refresh_token is not None + + # Get a pilot token, and try to access a user endpoint + r = normal_test_client.get( + "/api/auth/userinfo", headers={"Authorization": f"Bearer {access_token}"} + ) + + assert r.status_code == 401 + + # Get a pilot token, and try to access a pilot endpoint + r = normal_test_client.get( + "/api/pilots/internal/pilotinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert r.status_code == 200 + + +async def test_refresh_pilot_token_with_user_token( + normal_test_client: TestClient, + auth_httpx_mock: HTTPXMock, # noqa: F811 +): + # ----------------- Exchange for new tokens but with a user token ----------------- + # This will test that a user can't access a pilot endpoint *by default* + access_token = normal_test_client.headers["Authorization"] + + refresh_token = _get_tokens(normal_test_client)["refresh_token"] + + assert access_token + assert refresh_token + + # ----------------- First, with a pilot that does not exist ----------------- + body = {"refresh_token": refresh_token, "pilot_stamp": "stamp_0"} + r = normal_test_client.post( + "/api/auth/pilot-token", + json=body, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert r.status_code == 401 + assert "not found" in r.json()["detail"] + + # ----------------- Then, with a pilot that does exist ----------------- + # First, we need to create this pilot + + pilot_stamp = "stamp_1" + body = {"vo": MAIN_VO, "pilot_stamps": [pilot_stamp]} + + # Create a pilot + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = {"refresh_token": refresh_token, "pilot_stamp": "stamp_1"} + r = normal_test_client.post( + "/api/auth/pilot-token", + json=body, + ) + + assert r.status_code == 401 + assert r.json()["detail"] == "This is not a pilot token." + + +async def test_get_pilot_info_with_user_token( + normal_test_client: TestClient, +): + r = normal_test_client.get( + "/api/pilots/internal/pilotinfo", + ) + + assert r.status_code == 401 + + +@pytest.mark.parametrize("secret_duration_sec", [10]) +async def test_refresh_user_token_with_pilot_token( + normal_test_client, add_secrets_and_time +): + # Add pilots + result = add_secrets_and_time + stamps = result["stamps"] + secrets = result["secrets"] + + pilot_stamp = stamps[0] + secret = secrets[0] + + # ----------------- Good password ----------------- + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post("/api/auth/secret-exchange", json=body) + + assert r.status_code == 200, r.json() + + refresh_token = r.json()["refresh_token"] + + request_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": DIRAC_CLIENT_ID, + } + + r = normal_test_client.post("/api/auth/token", data=request_data) + + assert r.status_code == 403, r.json() + assert r.json()["detail"] == "This is not a user token." diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index c055727c9..064517bda 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -50,6 +50,21 @@ async def test_create_pilots(normal_test_client): assert r.status_code == 200, r.json() + # -------------- Logins -------------- + pilot_credentials_list = r.json() + for credentials in pilot_credentials_list: + pilot_stamp, secret = (credentials["pilot_stamp"], credentials["pilot_secret"]) + + body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret} + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 200, r.json() + # -------------- Register a pilot that already exists, and one that does not -------------- body = { @@ -85,6 +100,156 @@ async def test_create_pilots(normal_test_client): ) assert r.status_code == 200 + secret = r.json()[0]["pilot_secret"] + + # -------------- Login with a pilot that does not exist **but** was called before in an error -------------- + + body = { + "pilot_stamp": pilot_stamps[0] + "_new_one", + "pilot_secret": secret, + } + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 200, r.json() + + # -------------- Login with a pilot credentials of another pilot -------------- + + body = { + "pilot_stamp": pilot_stamps[0] + "_new_one", + "pilot_secret": pilot_credentials_list[0][ + "pilot_secret" + ], # [0] = first pilot from the list before, [1] = the secret + } + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 401, r.json() + assert r.json()["detail"] == "bad pilot_secret" + + +async def test_create_secrets_and_login(normal_test_client): + pilot_stamps = [f"stamps_{i}" for i in range(N)] + + # -------------- Create N secrets. -------------- + + body = { + "n": N, + "vo": MAIN_VO, + "expiration_minutes": 1, + "pilot_secret_use_count_max": 2 * N, # Used later + } + + r = normal_test_client.post( + "/api/pilots/secrets", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 200, r.json() + + # Format : {"pilot_secret": "...", "pilot_secret_expires_in": ..., "pilot_stamps": None} + secrets_mapping = r.json() + + secrets = [el["pilot_secret"] for el in secrets_mapping] + + assert len(secrets) == N + + # -------------- Create pilot *without* secrets -------------- + + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps, "generate_secrets": False} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Associate pilot with bad secrets -------------- + + body = {"bad_secret": {"PilotStamps": pilot_stamps}} + + r = normal_test_client.patch( + "/api/pilots/secrets", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 400, r.json() + assert r.json()["detail"] == "one of the secrets does not exist" + + # -------------- Associate pilot with secrets -------------- + + body = { + pilot_secret: {"PilotStamps": [pilot_stamp]} + for pilot_secret, pilot_stamp in zip(secrets, pilot_stamps) + } + + r = normal_test_client.patch( + "/api/pilots/secrets", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 204 + # -------------- Login with the right credentials -------------- + + for stamp, secret in zip(pilot_stamps, secrets): + body = {"pilot_secret": secret, "pilot_stamp": stamp} + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 200, r.json() + + # -------------- Login with the wrong credentials -------------- + + body = {"pilot_secret": secrets[1], "pilot_stamp": pilot_stamps[0]} + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 401, r.json() + + # -------------- Associate everyone to secrets[1] -------------- + + # Allowed by the router to avoid sending thousands of the same secret, if we want bunch of pilots to share a secret + body = {secrets[1]: {"PilotStamps": pilot_stamps}} + + r = normal_test_client.patch( + "/api/pilots/secrets", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 204 + + # -------------- Login with the right credentials -------------- + for stamp in pilot_stamps: + body = {"pilot_secret": secrets[1], "pilot_stamp": stamp} + + r = normal_test_client.post( + "/api/auth/secret-exchange", + json=body, + headers={"Content-Type": "application/json"}, + ) + + assert r.status_code == 200, r.json() async def test_create_pilot_and_delete_it(normal_test_client): diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 2991f4777..d37153a73 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -381,6 +381,27 @@ def admin_user(self): client.dirac_token_payload = payload yield client + @contextlib.contextmanager + def diracx_pilot(self): + """Pilot from DiracX: a specific token.""" + from diracx.routers.auth.token import create_token + + with self.unauthenticated() as client: + payload = { + "sub": "testingVO:yellow-sub", + "exp": datetime.now(tz=timezone.utc) + + timedelta(self.test_auth_settings.access_token_expire_minutes), + "iss": ISSUER, + "jti": str(uuid7()), + "pilot_stamp": "pilot_stamp", + "vo": "lhcb", + } + token = create_token(payload, self.test_auth_settings) + + client.headers["Authorization"] = f"Bearer {token}" + client.dirac_token_payload = payload + yield client + @pytest.fixture(scope="session") def session_client_factory( From 57eff5e4ac961a7f7ac004449eb69fa5a6a6c133 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 5 Aug 2025 12:11:40 +0200 Subject: [PATCH 06/11] chore: Regenerate client (post-fix) --- .../src/diracx/client/_generated/_client.py | 12 +- .../diracx/client/_generated/aio/_client.py | 12 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 499 ++++++++++++++- .../client/_generated/models/__init__.py | 18 + .../client/_generated/models/_models.py | 354 +++++++++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 575 +++++++++++++++++- .../src/gubbins/client/_generated/_client.py | 6 +- .../gubbins/client/_generated/aio/_client.py | 6 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 499 ++++++++++++++- .../client/_generated/models/__init__.py | 18 + .../client/_generated/models/_models.py | 354 +++++++++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 575 +++++++++++++++++- 16 files changed, 2888 insertions(+), 48 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index 9e37d5081..2fec8ebd3 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsInternalOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_internal: PilotsInternalOperations operations + :vartype pilots_internal: _generated.operations.PilotsInternalOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_internal = PilotsInternalOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 397b7f989..774ed256d 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsInternalOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_internal: PilotsInternalOperations operations + :vartype pilots_internal: _generated.aio.operations.PilotsInternalOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_internal = PilotsInternalOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index be02776fc..1d17089e1 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsInternalOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "PilotsOperations", + "PilotsInternalOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 82c27ee1b..268176e8d 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -34,6 +34,8 @@ build_auth_get_refresh_tokens_request, build_auth_initiate_authorization_flow_request, build_auth_initiate_device_flow_request, + build_auth_perform_secret_exchange_request, + build_auth_refresh_pilot_tokens_request, build_auth_revoke_refresh_token_by_jti_request, build_auth_revoke_refresh_token_by_refresh_token_request, build_auth_userinfo_request, @@ -53,11 +55,14 @@ build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, build_pilots_add_pilot_stamps_request, + build_pilots_create_pilot_secrets_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, + build_pilots_internal_userinfo_request, build_pilots_search_request, build_pilots_summary_request, build_pilots_update_pilot_fields_request, + build_pilots_update_secrets_constraints_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -903,6 +908,210 @@ async def complete_authorization_flow(self, *, code: str, state: str, **kwargs: return deserialized # type: ignore + @overload + async def perform_secret_exchange( + self, body: _models.PilotCredentials, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: ~_generated.models.PilotCredentials + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def perform_secret_exchange( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def perform_secret_exchange( + self, body: Union[_models.PilotCredentials, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Is either a PilotCredentials type or a IO[bytes] type. Required. + :type body: ~_generated.models.PilotCredentials or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "PilotCredentials") + + _request = build_auth_perform_secret_exchange_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def refresh_pilot_tokens( + self, body: _models.BodyAuthRefreshPilotTokens, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def refresh_pilot_tokens( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def refresh_pilot_tokens( + self, body: Union[_models.BodyAuthRefreshPilotTokens, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Is either a BodyAuthRefreshPilotTokens type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyAuthRefreshPilotTokens") + + _request = build_auth_refresh_pilot_tokens_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + class ConfigOperations: """ @@ -2386,7 +2595,7 @@ def __init__(self, *args, **kwargs) -> None: @overload async def add_pilot_stamps( self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2398,13 +2607,15 @@ async def add_pilot_stamps( :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @overload - async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + async def add_pilot_stamps( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2416,13 +2627,15 @@ async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "applic :keyword content_type: Body Parameter content-type. Content type parameter for binary body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @distributed_trace_async - async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + async def add_pilot_stamps( + self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2431,8 +2644,8 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -2447,7 +2660,7 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Any] = kwargs.pop("cls", None) + cls: ClsType[Optional[List[_models.PilotCredentialsInfo]]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2477,7 +2690,7 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize("object", pipeline_response.http_response) + deserialized = self._deserialize("[PilotCredentialsInfo]", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2555,6 +2768,200 @@ async def delete_pilots( if cls: return cls(pipeline_response, None, {}) # type: ignore + @overload + async def create_pilot_secrets( + self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def create_pilot_secrets( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def create_pilot_secrets( + self, body: Union[_models.BodyPilotsCreatePilotSecrets, IO[bytes]], **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Is either a BodyPilotsCreatePilotSecrets type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets or IO[bytes] + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[_models.PilotSecretsInfo]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsCreatePilotSecrets") + + _request = build_pilots_create_pilot_secrets_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[PilotSecretsInfo]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def update_secrets_constraints( + self, body: Dict[str, _models.PilotSecretConstraints], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_secrets_constraints( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_secrets_constraints( + self, body: Union[Dict[str, _models.PilotSecretConstraints], IO[bytes]], **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Is either a {str: PilotSecretConstraints} type or a IO[bytes] type. Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "{PilotSecretConstraints}") + + _request = build_pilots_update_secrets_constraints_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + @overload async def update_pilot_fields( self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any @@ -2942,3 +3349,73 @@ async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsInternalOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots_internal` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def userinfo(self, *, authorization: Optional[str] = None, **kwargs: Any) -> _models.PilotInfo: + """Userinfo. + + Get information about the user's identity. + + :keyword authorization: Default value is None. + :paramtype authorization: str + :return: PilotInfo + :rtype: ~_generated.models.PilotInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.PilotInfo] = kwargs.pop("cls", None) + + _request = build_pilots_internal_userinfo_request( + authorization=authorization, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("PilotInfo", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index ae52349c3..a37bc168a 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -16,7 +16,9 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyAuthRefreshPilotTokens, BodyPilotsAddPilotStamps, + BodyPilotsCreatePilotSecrets, BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, @@ -29,7 +31,13 @@ JobStatusUpdate, Metadata, OpenIDConfiguration, + PilotAuthCredentials, + PilotCredentials, + PilotCredentialsInfo, PilotFieldsMapping, + PilotInfo, + PilotSecretConstraints, + PilotSecretsInfo, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -46,6 +54,7 @@ TokenResponse, UserInfoResponse, VOInfo, + VacuumPilotAuth, ValidationError, ValidationErrorLocItem, VectorSearchSpec, @@ -71,7 +80,9 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyAuthRefreshPilotTokens", "BodyPilotsAddPilotStamps", + "BodyPilotsCreatePilotSecrets", "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", @@ -84,7 +95,13 @@ "JobStatusUpdate", "Metadata", "OpenIDConfiguration", + "PilotAuthCredentials", + "PilotCredentials", + "PilotCredentialsInfo", "PilotFieldsMapping", + "PilotInfo", + "PilotSecretConstraints", + "PilotSecretsInfo", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -101,6 +118,7 @@ "TokenResponse", "UserInfoResponse", "VOInfo", + "VacuumPilotAuth", "ValidationError", "ValidationErrorLocItem", "VectorSearchSpec", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 8763de15c..0211ac35a 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -144,6 +144,37 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ super().__init__(**kwargs) self.job_ids = job_ids +class BodyAuthRefreshPilotTokens(_serialization.Model): + """Body_auth_refresh_pilot_tokens. + + All required parameters must be populated in order to send to server. + + :ivar refresh_token: Refresh Token given at login by DiracX. Required. + :vartype refresh_token: str + :ivar pilot_stamp: Pilot stamp. Required. + :vartype pilot_stamp: str + """ + + _validation = { + "refresh_token": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "refresh_token": {"key": "refresh_token", "type": "str"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword refresh_token: Refresh Token given at login by DiracX. Required. + :paramtype refresh_token: str + :keyword pilot_stamp: Pilot stamp. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.refresh_token = refresh_token + self.pilot_stamp = pilot_stamp class BodyPilotsAddPilotStamps(_serialization.Model): @@ -166,6 +197,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype pilot_status: str or ~_generated.models.PilotStatus + :ivar generate_secrets: If we want to create secrets with the pilots. + :vartype generate_secrets: bool + :ivar pilot_secret_use_count_max: How much time can a secret be used. + :vartype pilot_secret_use_count_max: int """ _validation = { @@ -181,6 +216,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, "pilot_status": {"key": "pilot_status", "type": "str"}, + "generate_secrets": {"key": "generate_secrets", "type": "bool"}, + "pilot_secret_use_count_max": {"key": "pilot_secret_use_count_max", "type": "int"}, } def __init__( @@ -193,6 +230,8 @@ def __init__( destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + generate_secrets: bool = True, + pilot_secret_use_count_max: int = 1, **kwargs: Any ) -> None: """ @@ -211,6 +250,10 @@ def __init__( :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype pilot_status: str or ~_generated.models.PilotStatus + :keyword generate_secrets: If we want to create secrets with the pilots. + :paramtype generate_secrets: bool + :keyword pilot_secret_use_count_max: How much time can a secret be used. + :paramtype pilot_secret_use_count_max: int """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps @@ -220,6 +263,57 @@ def __init__( self.destination_site = destination_site self.pilot_references = pilot_references self.pilot_status = pilot_status + self.generate_secrets = generate_secrets + self.pilot_secret_use_count_max = pilot_secret_use_count_max + + +class BodyPilotsCreatePilotSecrets(_serialization.Model): + """Body_pilots_create_pilot_secrets. + + All required parameters must be populated in order to send to server. + + :ivar n: Number of secrets to create. Required. + :vartype n: int + :ivar expiration_minutes: Time in minutes before expiring. Required. + :vartype expiration_minutes: int + :ivar pilot_secret_use_count_max: Number of times that we can use a secret. Required. + :vartype pilot_secret_use_count_max: int + :ivar vo: Only VO that can access a secret. Required. + :vartype vo: str + """ + + _validation = { + "n": {"required": True}, + "expiration_minutes": {"required": True}, + "pilot_secret_use_count_max": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "n": {"key": "n", "type": "int"}, + "expiration_minutes": {"key": "expiration_minutes", "type": "int"}, + "pilot_secret_use_count_max": {"key": "pilot_secret_use_count_max", "type": "int"}, + "vo": {"key": "vo", "type": "str"}, + } + + def __init__( + self, *, n: int, expiration_minutes: int, pilot_secret_use_count_max: int, vo: str, **kwargs: Any + ) -> None: + """ + :keyword n: Number of secrets to create. Required. + :paramtype n: int + :keyword expiration_minutes: Time in minutes before expiring. Required. + :paramtype expiration_minutes: int + :keyword pilot_secret_use_count_max: Number of times that we can use a secret. Required. + :paramtype pilot_secret_use_count_max: int + :keyword vo: Only VO that can access a secret. Required. + :paramtype vo: str + """ + super().__init__(**kwargs) + self.n = n + self.expiration_minutes = expiration_minutes + self.pilot_secret_use_count_max = pilot_secret_use_count_max + self.vo = vo class BodyPilotsUpdatePilotFields(_serialization.Model): @@ -989,6 +1083,83 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotAuthCredentials(_serialization.Model): + """PilotAuthCredentials. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_secret": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_secret: str, **kwargs: Any) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_secret = pilot_secret + + +class PilotCredentials(_serialization.Model): + """Pilot credentials (stamp and secret).""" + + +class PilotCredentialsInfo(_serialization.Model): + """PilotCredentialsInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar pilot_secret_expires_in: Pilot Secret Expires In. Required. + :vartype pilot_secret_expires_in: int + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + """ + + _validation = { + "pilot_secret": {"required": True}, + "pilot_secret_expires_in": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "pilot_secret_expires_in": {"key": "pilot_secret_expires_in", "type": "int"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, pilot_secret: str, pilot_secret_expires_in: int, pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword pilot_secret_expires_in: Pilot Secret Expires In. Required. + :paramtype pilot_secret_expires_in: int + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.pilot_secret = pilot_secret + self.pilot_secret_expires_in = pilot_secret_expires_in + self.pilot_stamp = pilot_stamp + + class PilotFieldsMapping(_serialization.Model): """All the fields that a user can modify on a Pilot (except PilotStamp). @@ -1085,6 +1256,118 @@ def __init__( self.current_job_id = current_job_id +class PilotInfo(_serialization.Model): + """PilotInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar vo: Vo. Required. + :vartype vo: str + :ivar sub: Sub. Required. + :vartype sub: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "vo": {"required": True}, + "sub": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "vo": {"key": "vo", "type": "str"}, + "sub": {"key": "sub", "type": "str"}, + } + + def __init__(self, *, pilot_stamp: str, vo: str, sub: str, **kwargs: Any) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword vo: Vo. Required. + :paramtype vo: str + :keyword sub: Sub. Required. + :paramtype sub: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.vo = vo + self.sub = sub + + +class PilotSecretConstraints(_serialization.Model): + """PilotSecretConstraints. + + :ivar v_os: Vos. + :vartype v_os: list[str] + :ivar pilot_stamps: Pilotstamps. + :vartype pilot_stamps: list[str] + :ivar sites: Sites. + :vartype sites: list[str] + """ + + _attribute_map = { + "v_os": {"key": "VOs", "type": "[str]"}, + "pilot_stamps": {"key": "PilotStamps", "type": "[str]"}, + "sites": {"key": "Sites", "type": "[str]"}, + } + + def __init__( + self, + *, + v_os: Optional[List[str]] = None, + pilot_stamps: Optional[List[str]] = None, + sites: Optional[List[str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword v_os: Vos. + :paramtype v_os: list[str] + :keyword pilot_stamps: Pilotstamps. + :paramtype pilot_stamps: list[str] + :keyword sites: Sites. + :paramtype sites: list[str] + """ + super().__init__(**kwargs) + self.v_os = v_os + self.pilot_stamps = pilot_stamps + self.sites = sites + + +class PilotSecretsInfo(_serialization.Model): + """PilotSecretsInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar pilot_secret_expires_in: Pilot Secret Expires In. Required. + :vartype pilot_secret_expires_in: int + """ + + _validation = { + "pilot_secret": {"required": True}, + "pilot_secret_expires_in": {"required": True}, + } + + _attribute_map = { + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "pilot_secret_expires_in": {"key": "pilot_secret_expires_in", "type": "int"}, + } + + def __init__(self, *, pilot_secret: str, pilot_secret_expires_in: int, **kwargs: Any) -> None: + """ + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword pilot_secret_expires_in: Pilot Secret Expires In. Required. + :paramtype pilot_secret_expires_in: int + """ + super().__init__(**kwargs) + self.pilot_secret = pilot_secret + self.pilot_secret_expires_in = pilot_secret_expires_in + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. @@ -1661,6 +1944,77 @@ def __init__( self.preferred_username = preferred_username +class VacuumPilotAuth(_serialization.Model): + """VacuumPilotAuth. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar vo: Vo. Required. + :vartype vo: str + :ivar grid_type: Grid Type. Required. + :vartype grid_type: str + :ivar grid_site: Grid Site. Required. + :vartype grid_site: str + :ivar status: Status. Required. + :vartype status: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_secret": {"required": True}, + "vo": {"required": True}, + "grid_type": {"required": True}, + "grid_site": {"required": True}, + "status": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "status": {"key": "status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + pilot_secret: str, + vo: str, + grid_type: str, + grid_site: str, + status: str, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword vo: Vo. Required. + :paramtype vo: str + :keyword grid_type: Grid Type. Required. + :paramtype grid_type: str + :keyword grid_site: Grid Site. Required. + :paramtype grid_site: str + :keyword status: Status. Required. + :paramtype status: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_secret = pilot_secret + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.status = status + + class ValidationError(_serialization.Model): """ValidationError. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index be02776fc..1d17089e1 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsInternalOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "PilotsOperations", + "PilotsInternalOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index c682d2a3a..55f846ee3 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -286,6 +286,40 @@ def build_auth_complete_authorization_flow_request( # pylint: disable=name-too- return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) +def build_auth_perform_secret_exchange_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/auth/secret-exchange" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_auth_refresh_pilot_tokens_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/auth/pilot-token" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + def build_config_serve_config_request( *, if_modified_since: Optional[str] = None, @@ -626,6 +660,37 @@ def build_pilots_delete_pilots_request( return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) +def build_pilots_create_pilot_secrets_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/secrets" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_update_secrets_constraints_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/secrets" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -704,6 +769,22 @@ def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_internal_userinfo_request(*, authorization: Optional[str] = None, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/internal/pilotinfo" + + # Construct headers + if authorization is not None: + _headers["authorization"] = _SERIALIZER.header("authorization", authorization, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -1536,6 +1617,210 @@ def complete_authorization_flow(self, *, code: str, state: str, **kwargs: Any) - return deserialized # type: ignore + @overload + def perform_secret_exchange( + self, body: _models.PilotCredentials, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: ~_generated.models.PilotCredentials + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def perform_secret_exchange( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def perform_secret_exchange( + self, body: Union[_models.PilotCredentials, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Is either a PilotCredentials type or a IO[bytes] type. Required. + :type body: ~_generated.models.PilotCredentials or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "PilotCredentials") + + _request = build_auth_perform_secret_exchange_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def refresh_pilot_tokens( + self, body: _models.BodyAuthRefreshPilotTokens, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def refresh_pilot_tokens( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def refresh_pilot_tokens( + self, body: Union[_models.BodyAuthRefreshPilotTokens, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Is either a BodyAuthRefreshPilotTokens type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyAuthRefreshPilotTokens") + + _request = build_auth_refresh_pilot_tokens_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + class ConfigOperations: """ @@ -3017,7 +3302,7 @@ def __init__(self, *args, **kwargs) -> None: @overload def add_pilot_stamps( self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3029,13 +3314,15 @@ def add_pilot_stamps( :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @overload - def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + def add_pilot_stamps( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3047,13 +3334,15 @@ def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/ :keyword content_type: Body Parameter content-type. Content type parameter for binary body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @distributed_trace - def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + def add_pilot_stamps( + self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3062,8 +3351,8 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -3078,7 +3367,7 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Any] = kwargs.pop("cls", None) + cls: ClsType[Optional[List[_models.PilotCredentialsInfo]]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3108,7 +3397,7 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize("object", pipeline_response.http_response) + deserialized = self._deserialize("[PilotCredentialsInfo]", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3186,6 +3475,200 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements if cls: return cls(pipeline_response, None, {}) # type: ignore + @overload + def create_pilot_secrets( + self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def create_pilot_secrets( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def create_pilot_secrets( + self, body: Union[_models.BodyPilotsCreatePilotSecrets, IO[bytes]], **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Is either a BodyPilotsCreatePilotSecrets type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets or IO[bytes] + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[_models.PilotSecretsInfo]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsCreatePilotSecrets") + + _request = build_pilots_create_pilot_secrets_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[PilotSecretsInfo]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def update_secrets_constraints( + self, body: Dict[str, _models.PilotSecretConstraints], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_secrets_constraints( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_secrets_constraints( # pylint: disable=inconsistent-return-statements + self, body: Union[Dict[str, _models.PilotSecretConstraints], IO[bytes]], **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Is either a {str: PilotSecretConstraints} type or a IO[bytes] type. Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "{PilotSecretConstraints}") + + _request = build_pilots_update_secrets_constraints_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + @overload def update_pilot_fields( self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any @@ -3569,3 +4052,73 @@ def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsInternalOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_internal` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def userinfo(self, *, authorization: Optional[str] = None, **kwargs: Any) -> _models.PilotInfo: + """Userinfo. + + Get information about the user's identity. + + :keyword authorization: Default value is None. + :paramtype authorization: str + :return: PilotInfo + :rtype: ~_generated.models.PilotInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.PilotInfo] = kwargs.pop("cls", None) + + _request = build_pilots_internal_userinfo_request( + authorization=authorization, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("PilotInfo", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index fdf17b6a3..d6f7ee08b 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -20,12 +20,13 @@ ConfigOperations, JobsOperations, LollygagOperations, + PilotsInternalOperations, PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.operations.LollygagOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_internal: PilotsInternalOperations operations + :vartype pilots_internal: _generated.operations.PilotsInternalOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_internal = PilotsInternalOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index 76280797e..d0af9cbd9 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -20,12 +20,13 @@ ConfigOperations, JobsOperations, LollygagOperations, + PilotsInternalOperations, PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.aio.operations.LollygagOperations :ivar pilots: PilotsOperations operations :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_internal: PilotsInternalOperations operations + :vartype pilots_internal: _generated.aio.operations.PilotsInternalOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_internal = PilotsInternalOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 3408891fc..9eb133a2d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsInternalOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "PilotsOperations", + "PilotsInternalOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 19925b650..44c3f8572 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -34,6 +34,8 @@ build_auth_get_refresh_tokens_request, build_auth_initiate_authorization_flow_request, build_auth_initiate_device_flow_request, + build_auth_perform_secret_exchange_request, + build_auth_refresh_pilot_tokens_request, build_auth_revoke_refresh_token_by_jti_request, build_auth_revoke_refresh_token_by_refresh_token_request, build_auth_userinfo_request, @@ -56,11 +58,14 @@ build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, build_pilots_add_pilot_stamps_request, + build_pilots_create_pilot_secrets_request, build_pilots_delete_pilots_request, build_pilots_get_pilot_jobs_request, + build_pilots_internal_userinfo_request, build_pilots_search_request, build_pilots_summary_request, build_pilots_update_pilot_fields_request, + build_pilots_update_secrets_constraints_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -906,6 +911,210 @@ async def complete_authorization_flow(self, *, code: str, state: str, **kwargs: return deserialized # type: ignore + @overload + async def perform_secret_exchange( + self, body: _models.PilotCredentials, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: ~_generated.models.PilotCredentials + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def perform_secret_exchange( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def perform_secret_exchange( + self, body: Union[_models.PilotCredentials, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Is either a PilotCredentials type or a IO[bytes] type. Required. + :type body: ~_generated.models.PilotCredentials or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "PilotCredentials") + + _request = build_auth_perform_secret_exchange_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def refresh_pilot_tokens( + self, body: _models.BodyAuthRefreshPilotTokens, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def refresh_pilot_tokens( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def refresh_pilot_tokens( + self, body: Union[_models.BodyAuthRefreshPilotTokens, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Is either a BodyAuthRefreshPilotTokens type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyAuthRefreshPilotTokens") + + _request = build_auth_refresh_pilot_tokens_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + class ConfigOperations: """ @@ -2553,7 +2762,7 @@ def __init__(self, *args, **kwargs) -> None: @overload async def add_pilot_stamps( self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2565,13 +2774,15 @@ async def add_pilot_stamps( :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @overload - async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + async def add_pilot_stamps( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2583,13 +2794,15 @@ async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "applic :keyword content_type: Body Parameter content-type. Content type parameter for binary body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @distributed_trace_async - async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + async def add_pilot_stamps( + self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -2598,8 +2811,8 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -2614,7 +2827,7 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Any] = kwargs.pop("cls", None) + cls: ClsType[Optional[List[_models.PilotCredentialsInfo]]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2644,7 +2857,7 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize("object", pipeline_response.http_response) + deserialized = self._deserialize("[PilotCredentialsInfo]", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2722,6 +2935,200 @@ async def delete_pilots( if cls: return cls(pipeline_response, None, {}) # type: ignore + @overload + async def create_pilot_secrets( + self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def create_pilot_secrets( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def create_pilot_secrets( + self, body: Union[_models.BodyPilotsCreatePilotSecrets, IO[bytes]], **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Is either a BodyPilotsCreatePilotSecrets type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets or IO[bytes] + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[_models.PilotSecretsInfo]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsCreatePilotSecrets") + + _request = build_pilots_create_pilot_secrets_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[PilotSecretsInfo]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def update_secrets_constraints( + self, body: Dict[str, _models.PilotSecretConstraints], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_secrets_constraints( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_secrets_constraints( + self, body: Union[Dict[str, _models.PilotSecretConstraints], IO[bytes]], **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Is either a {str: PilotSecretConstraints} type or a IO[bytes] type. Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "{PilotSecretConstraints}") + + _request = build_pilots_update_secrets_constraints_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + @overload async def update_pilot_fields( self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any @@ -3109,3 +3516,73 @@ async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsInternalOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots_internal` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def userinfo(self, *, authorization: Optional[str] = None, **kwargs: Any) -> _models.PilotInfo: + """Userinfo. + + Get information about the user's identity. + + :keyword authorization: Default value is None. + :paramtype authorization: str + :return: PilotInfo + :rtype: ~_generated.models.PilotInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.PilotInfo] = kwargs.pop("cls", None) + + _request = build_pilots_internal_userinfo_request( + authorization=authorization, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("PilotInfo", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 7bdd59b63..a0e841992 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -16,7 +16,9 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyAuthRefreshPilotTokens, BodyPilotsAddPilotStamps, + BodyPilotsCreatePilotSecrets, BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, @@ -29,7 +31,13 @@ JobMetaDataAccountedFlag, JobStatusUpdate, OpenIDConfiguration, + PilotAuthCredentials, + PilotCredentials, + PilotCredentialsInfo, PilotFieldsMapping, + PilotInfo, + PilotSecretConstraints, + PilotSecretsInfo, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -46,6 +54,7 @@ TokenResponse, UserInfoResponse, VOInfo, + VacuumPilotAuth, ValidationError, ValidationErrorLocItem, VectorSearchSpec, @@ -71,7 +80,9 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyAuthRefreshPilotTokens", "BodyPilotsAddPilotStamps", + "BodyPilotsCreatePilotSecrets", "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", @@ -84,7 +95,13 @@ "JobMetaDataAccountedFlag", "JobStatusUpdate", "OpenIDConfiguration", + "PilotAuthCredentials", + "PilotCredentials", + "PilotCredentialsInfo", "PilotFieldsMapping", + "PilotInfo", + "PilotSecretConstraints", + "PilotSecretsInfo", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -101,6 +118,7 @@ "TokenResponse", "UserInfoResponse", "VOInfo", + "VacuumPilotAuth", "ValidationError", "ValidationErrorLocItem", "VectorSearchSpec", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 2e8717cb6..f07d962f1 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -144,6 +144,37 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ super().__init__(**kwargs) self.job_ids = job_ids +class BodyAuthRefreshPilotTokens(_serialization.Model): + """Body_auth_refresh_pilot_tokens. + + All required parameters must be populated in order to send to server. + + :ivar refresh_token: Refresh Token given at login by DiracX. Required. + :vartype refresh_token: str + :ivar pilot_stamp: Pilot stamp. Required. + :vartype pilot_stamp: str + """ + + _validation = { + "refresh_token": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "refresh_token": {"key": "refresh_token", "type": "str"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword refresh_token: Refresh Token given at login by DiracX. Required. + :paramtype refresh_token: str + :keyword pilot_stamp: Pilot stamp. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.refresh_token = refresh_token + self.pilot_stamp = pilot_stamp class BodyPilotsAddPilotStamps(_serialization.Model): @@ -166,6 +197,10 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype pilot_status: str or ~_generated.models.PilotStatus + :ivar generate_secrets: If we want to create secrets with the pilots. + :vartype generate_secrets: bool + :ivar pilot_secret_use_count_max: How much time can a secret be used. + :vartype pilot_secret_use_count_max: int """ _validation = { @@ -181,6 +216,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): "destination_site": {"key": "destination_site", "type": "str"}, "pilot_references": {"key": "pilot_references", "type": "{str}"}, "pilot_status": {"key": "pilot_status", "type": "str"}, + "generate_secrets": {"key": "generate_secrets", "type": "bool"}, + "pilot_secret_use_count_max": {"key": "pilot_secret_use_count_max", "type": "int"}, } def __init__( @@ -193,6 +230,8 @@ def __init__( destination_site: str = "NotAssigned", pilot_references: Optional[Dict[str, str]] = None, pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + generate_secrets: bool = True, + pilot_secret_use_count_max: int = 1, **kwargs: Any ) -> None: """ @@ -211,6 +250,10 @@ def __init__( :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype pilot_status: str or ~_generated.models.PilotStatus + :keyword generate_secrets: If we want to create secrets with the pilots. + :paramtype generate_secrets: bool + :keyword pilot_secret_use_count_max: How much time can a secret be used. + :paramtype pilot_secret_use_count_max: int """ super().__init__(**kwargs) self.pilot_stamps = pilot_stamps @@ -220,6 +263,57 @@ def __init__( self.destination_site = destination_site self.pilot_references = pilot_references self.pilot_status = pilot_status + self.generate_secrets = generate_secrets + self.pilot_secret_use_count_max = pilot_secret_use_count_max + + +class BodyPilotsCreatePilotSecrets(_serialization.Model): + """Body_pilots_create_pilot_secrets. + + All required parameters must be populated in order to send to server. + + :ivar n: Number of secrets to create. Required. + :vartype n: int + :ivar expiration_minutes: Time in minutes before expiring. Required. + :vartype expiration_minutes: int + :ivar pilot_secret_use_count_max: Number of times that we can use a secret. Required. + :vartype pilot_secret_use_count_max: int + :ivar vo: Only VO that can access a secret. Required. + :vartype vo: str + """ + + _validation = { + "n": {"required": True}, + "expiration_minutes": {"required": True}, + "pilot_secret_use_count_max": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "n": {"key": "n", "type": "int"}, + "expiration_minutes": {"key": "expiration_minutes", "type": "int"}, + "pilot_secret_use_count_max": {"key": "pilot_secret_use_count_max", "type": "int"}, + "vo": {"key": "vo", "type": "str"}, + } + + def __init__( + self, *, n: int, expiration_minutes: int, pilot_secret_use_count_max: int, vo: str, **kwargs: Any + ) -> None: + """ + :keyword n: Number of secrets to create. Required. + :paramtype n: int + :keyword expiration_minutes: Time in minutes before expiring. Required. + :paramtype expiration_minutes: int + :keyword pilot_secret_use_count_max: Number of times that we can use a secret. Required. + :paramtype pilot_secret_use_count_max: int + :keyword vo: Only VO that can access a secret. Required. + :paramtype vo: str + """ + super().__init__(**kwargs) + self.n = n + self.expiration_minutes = expiration_minutes + self.pilot_secret_use_count_max = pilot_secret_use_count_max + self.vo = vo class BodyPilotsUpdatePilotFields(_serialization.Model): @@ -1010,6 +1104,83 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotAuthCredentials(_serialization.Model): + """PilotAuthCredentials. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_secret": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + } + + def __init__(self, *, pilot_stamp: str, pilot_secret: str, **kwargs: Any) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_secret = pilot_secret + + +class PilotCredentials(_serialization.Model): + """Pilot credentials (stamp and secret).""" + + +class PilotCredentialsInfo(_serialization.Model): + """PilotCredentialsInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar pilot_secret_expires_in: Pilot Secret Expires In. Required. + :vartype pilot_secret_expires_in: int + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + """ + + _validation = { + "pilot_secret": {"required": True}, + "pilot_secret_expires_in": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "pilot_secret_expires_in": {"key": "pilot_secret_expires_in", "type": "int"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, pilot_secret: str, pilot_secret_expires_in: int, pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword pilot_secret_expires_in: Pilot Secret Expires In. Required. + :paramtype pilot_secret_expires_in: int + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.pilot_secret = pilot_secret + self.pilot_secret_expires_in = pilot_secret_expires_in + self.pilot_stamp = pilot_stamp + + class PilotFieldsMapping(_serialization.Model): """All the fields that a user can modify on a Pilot (except PilotStamp). @@ -1106,6 +1277,118 @@ def __init__( self.current_job_id = current_job_id +class PilotInfo(_serialization.Model): + """PilotInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar vo: Vo. Required. + :vartype vo: str + :ivar sub: Sub. Required. + :vartype sub: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "vo": {"required": True}, + "sub": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "vo": {"key": "vo", "type": "str"}, + "sub": {"key": "sub", "type": "str"}, + } + + def __init__(self, *, pilot_stamp: str, vo: str, sub: str, **kwargs: Any) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword vo: Vo. Required. + :paramtype vo: str + :keyword sub: Sub. Required. + :paramtype sub: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.vo = vo + self.sub = sub + + +class PilotSecretConstraints(_serialization.Model): + """PilotSecretConstraints. + + :ivar v_os: Vos. + :vartype v_os: list[str] + :ivar pilot_stamps: Pilotstamps. + :vartype pilot_stamps: list[str] + :ivar sites: Sites. + :vartype sites: list[str] + """ + + _attribute_map = { + "v_os": {"key": "VOs", "type": "[str]"}, + "pilot_stamps": {"key": "PilotStamps", "type": "[str]"}, + "sites": {"key": "Sites", "type": "[str]"}, + } + + def __init__( + self, + *, + v_os: Optional[List[str]] = None, + pilot_stamps: Optional[List[str]] = None, + sites: Optional[List[str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword v_os: Vos. + :paramtype v_os: list[str] + :keyword pilot_stamps: Pilotstamps. + :paramtype pilot_stamps: list[str] + :keyword sites: Sites. + :paramtype sites: list[str] + """ + super().__init__(**kwargs) + self.v_os = v_os + self.pilot_stamps = pilot_stamps + self.sites = sites + + +class PilotSecretsInfo(_serialization.Model): + """PilotSecretsInfo. + + All required parameters must be populated in order to send to server. + + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar pilot_secret_expires_in: Pilot Secret Expires In. Required. + :vartype pilot_secret_expires_in: int + """ + + _validation = { + "pilot_secret": {"required": True}, + "pilot_secret_expires_in": {"required": True}, + } + + _attribute_map = { + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "pilot_secret_expires_in": {"key": "pilot_secret_expires_in", "type": "int"}, + } + + def __init__(self, *, pilot_secret: str, pilot_secret_expires_in: int, **kwargs: Any) -> None: + """ + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword pilot_secret_expires_in: Pilot Secret Expires In. Required. + :paramtype pilot_secret_expires_in: int + """ + super().__init__(**kwargs) + self.pilot_secret = pilot_secret + self.pilot_secret_expires_in = pilot_secret_expires_in + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. @@ -1682,6 +1965,77 @@ def __init__( self.preferred_username = preferred_username +class VacuumPilotAuth(_serialization.Model): + """VacuumPilotAuth. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilot Stamp. Required. + :vartype pilot_stamp: str + :ivar pilot_secret: Pilot Secret. Required. + :vartype pilot_secret: str + :ivar vo: Vo. Required. + :vartype vo: str + :ivar grid_type: Grid Type. Required. + :vartype grid_type: str + :ivar grid_site: Grid Site. Required. + :vartype grid_site: str + :ivar status: Status. Required. + :vartype status: str + """ + + _validation = { + "pilot_stamp": {"required": True}, + "pilot_secret": {"required": True}, + "vo": {"required": True}, + "grid_type": {"required": True}, + "grid_site": {"required": True}, + "status": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "pilot_secret": {"key": "pilot_secret", "type": "str"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "status": {"key": "status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + pilot_secret: str, + vo: str, + grid_type: str, + grid_site: str, + status: str, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilot Stamp. Required. + :paramtype pilot_stamp: str + :keyword pilot_secret: Pilot Secret. Required. + :paramtype pilot_secret: str + :keyword vo: Vo. Required. + :paramtype vo: str + :keyword grid_type: Grid Type. Required. + :paramtype grid_type: str + :keyword grid_site: Grid Site. Required. + :paramtype grid_site: str + :keyword status: Status. Required. + :paramtype status: str + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.pilot_secret = pilot_secret + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.status = status + + class ValidationError(_serialization.Model): """ValidationError. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 3408891fc..9eb133a2d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsInternalOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "PilotsOperations", + "PilotsInternalOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 4358ecf51..9bc744318 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -286,6 +286,40 @@ def build_auth_complete_authorization_flow_request( # pylint: disable=name-too- return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) +def build_auth_perform_secret_exchange_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/auth/secret-exchange" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_auth_refresh_pilot_tokens_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/auth/pilot-token" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + def build_config_serve_config_request( *, if_modified_since: Optional[str] = None, @@ -675,6 +709,37 @@ def build_pilots_delete_pilots_request( return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) +def build_pilots_create_pilot_secrets_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/secrets" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_update_secrets_constraints_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/secrets" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -753,6 +818,22 @@ def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_internal_userinfo_request(*, authorization: Optional[str] = None, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/internal/pilotinfo" + + # Construct headers + if authorization is not None: + _headers["authorization"] = _SERIALIZER.header("authorization", authorization, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -1585,6 +1666,210 @@ def complete_authorization_flow(self, *, code: str, state: str, **kwargs: Any) - return deserialized # type: ignore + @overload + def perform_secret_exchange( + self, body: _models.PilotCredentials, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: ~_generated.models.PilotCredentials + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def perform_secret_exchange( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def perform_secret_exchange( + self, body: Union[_models.PilotCredentials, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Perform Secret Exchange. + + This endpoint is used by the pilot to exchange a secret for a token. + + This endpoint also acts as DIRAC's ``dirac-admin-add-pilot``. + + :param body: Is either a PilotCredentials type or a IO[bytes] type. Required. + :type body: ~_generated.models.PilotCredentials or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "PilotCredentials") + + _request = build_auth_perform_secret_exchange_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def refresh_pilot_tokens( + self, body: _models.BodyAuthRefreshPilotTokens, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def refresh_pilot_tokens( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def refresh_pilot_tokens( + self, body: Union[_models.BodyAuthRefreshPilotTokens, IO[bytes]], **kwargs: Any + ) -> _models.TokenResponse: + """Refresh Pilot Tokens. + + Endpoint where *only* pilots can exchange a refresh token for a token. + + :param body: Is either a BodyAuthRefreshPilotTokens type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyAuthRefreshPilotTokens or IO[bytes] + :return: TokenResponse + :rtype: ~_generated.models.TokenResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.TokenResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyAuthRefreshPilotTokens") + + _request = build_auth_refresh_pilot_tokens_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("TokenResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + class ConfigOperations: """ @@ -3230,7 +3515,7 @@ def __init__(self, *args, **kwargs) -> None: @overload def add_pilot_stamps( self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3242,13 +3527,15 @@ def add_pilot_stamps( :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @overload - def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + def add_pilot_stamps( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3260,13 +3547,15 @@ def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/ :keyword content_type: Body Parameter content-type. Content type parameter for binary body. Default value is "application/json". :paramtype content_type: str - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ @distributed_trace - def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + def add_pilot_stamps( + self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any + ) -> Optional[List[_models.PilotCredentialsInfo]]: """Add Pilot Stamps. Endpoint where a you can create pilots with their references. @@ -3275,8 +3564,8 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] - :return: any - :rtype: any + :return: list of PilotCredentialsInfo or None + :rtype: list[~_generated.models.PilotCredentialsInfo] or None :raises ~azure.core.exceptions.HttpResponseError: """ error_map: MutableMapping = { @@ -3291,7 +3580,7 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Any] = kwargs.pop("cls", None) + cls: ClsType[Optional[List[_models.PilotCredentialsInfo]]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3321,7 +3610,7 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize("object", pipeline_response.http_response) + deserialized = self._deserialize("[PilotCredentialsInfo]", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3399,6 +3688,200 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements if cls: return cls(pipeline_response, None, {}) # type: ignore + @overload + def create_pilot_secrets( + self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def create_pilot_secrets( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def create_pilot_secrets( + self, body: Union[_models.BodyPilotsCreatePilotSecrets, IO[bytes]], **kwargs: Any + ) -> List[_models.PilotSecretsInfo]: + """Create Pilot Secrets. + + Endpoint to create secrets. + + :param body: Is either a BodyPilotsCreatePilotSecrets type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsCreatePilotSecrets or IO[bytes] + :return: list of PilotSecretsInfo + :rtype: list[~_generated.models.PilotSecretsInfo] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[_models.PilotSecretsInfo]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsCreatePilotSecrets") + + _request = build_pilots_create_pilot_secrets_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[PilotSecretsInfo]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def update_secrets_constraints( + self, body: Dict[str, _models.PilotSecretConstraints], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_secrets_constraints( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_secrets_constraints( # pylint: disable=inconsistent-return-statements + self, body: Union[Dict[str, _models.PilotSecretConstraints], IO[bytes]], **kwargs: Any + ) -> None: + """Update Secrets Constraints. + + Endpoint to associate pilots with secrets. + + :param body: Is either a {str: PilotSecretConstraints} type or a IO[bytes] type. Required. + :type body: dict[str, ~_generated.models.PilotSecretConstraints] or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "{PilotSecretConstraints}") + + _request = build_pilots_update_secrets_constraints_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + @overload def update_pilot_fields( self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any @@ -3782,3 +4265,73 @@ def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsInternalOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_internal` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def userinfo(self, *, authorization: Optional[str] = None, **kwargs: Any) -> _models.PilotInfo: + """Userinfo. + + Get information about the user's identity. + + :keyword authorization: Default value is None. + :paramtype authorization: str + :return: PilotInfo + :rtype: ~_generated.models.PilotInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.PilotInfo] = kwargs.pop("cls", None) + + _request = build_pilots_internal_userinfo_request( + authorization=authorization, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("PilotInfo", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore From e07397bbf70a2a3ea7ce4f5eb1a98c7235baec71 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 14:20:05 +0200 Subject: [PATCH 07/11] fix: Moved secrets table to the AuthDB --- diracx-db/src/diracx/db/sql/auth/db.py | 157 +++++++++++++++- diracx-db/src/diracx/db/sql/auth/schema.py | 29 +++ diracx-db/src/diracx/db/sql/pilots/db.py | 155 +--------------- diracx-db/src/diracx/db/sql/pilots/schema.py | 30 --- .../tests/auth/test_authorization_flow.py | 2 +- diracx-db/tests/auth/test_device_flow.py | 2 +- diracx-db/tests/auth/test_refresh_token.py | 2 +- diracx-db/tests/pilots/test_pilot_auth.py | 166 ----------------- .../tests/pilots/test_pilot_management.py | 2 +- diracx-db/tests/pilots/test_query.py | 2 +- diracx-db/tests/pilots/utils.py | 173 +----------------- diracx-logic/src/diracx/logic/pilots/auth.py | 26 +-- .../src/diracx/logic/pilots/management.py | 7 +- diracx-logic/src/diracx/logic/pilots/query.py | 10 +- .../diracx/routers/pilots/access_policies.py | 2 +- .../src/diracx/routers/pilots/management.py | 12 +- .../tests/pilots/test_pilot_auth.py | 12 +- diracx-routers/tests/pilots/test_query.py | 1 + 18 files changed, 230 insertions(+), 560 deletions(-) delete mode 100644 diracx-db/tests/pilots/test_pilot_auth.py diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index 58e358370..5679364a0 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -1,21 +1,26 @@ from __future__ import annotations import secrets +from typing import Any -from sqlalchemy import insert, select, update +from sqlalchemy import DateTime, bindparam, delete, insert, select, update from sqlalchemy.exc import IntegrityError, NoResultFound from uuid_utils import UUID, uuid7 from diracx.core.exceptions import ( AuthorizationError, + SecretNotFoundError, TokenNotFoundError, ) +from diracx.core.models import PilotSecretConstraints, SearchSpec, SortSpec from diracx.db.sql.utils import BaseSQLDB, hash, substract_date +from diracx.db.sql.utils.functions import utcnow from .schema import ( AuthorizationFlows, DeviceFlows, FlowStatus, + PilotSecrets, RefreshTokens, RefreshTokenStatus, ) @@ -264,3 +269,153 @@ async def revoke_user_refresh_tokens(self, subject): .where(RefreshTokens.sub == subject) .values(status=RefreshTokenStatus.REVOKED) ) + + # ------------- Pilot secrets mechanism ------------- + + async def insert_unique_secrets( + self, + hashed_secrets: list[bytes], + secret_global_use_count_max: int | None = 1, + secret_constraints: dict[bytes, PilotSecretConstraints] = {}, + ): + """Bulk insert secrets. + + Raises: + - NotImplementedError if we have an IntegrityError not caught + + """ + values = [ + { + "SecretUUID": str(uuid7()), + "SecretRemainingUseCount": secret_global_use_count_max, + "HashedSecret": hashed_secret, + "SecretConstraints": secret_constraints.get(hashed_secret, {}), + } + for hashed_secret in hashed_secrets + ] + + stmt = insert(PilotSecrets).values(values) + await self.conn.execute(stmt) + + async def delete_secrets(self, secret_uuids: list[str]): + """Bulk delete secrets. + + Raises SecretNotFoundError if one of the secret was not found. + """ + stmt = delete(PilotSecrets).where(PilotSecrets.secret_uuid.in_(secret_uuids)) + + res = await self.conn.execute(stmt) + + if res.rowcount != len(secret_uuids): + raise SecretNotFoundError( + "At least one of the secret has not been deleted." + ) + + # We NEED to commit here, because we will raise an error after this function + await self.conn.commit() + + async def update_pilot_secret_use_time(self, secret_uuid: str) -> None: + """Updates when a pilot uses a secret. + + Raises PilotNotFoundError if the pilot does not exist + + """ + # Prepare the update statement + stmt = ( + update(PilotSecrets) + .values( + pilot_secret_use_date=utcnow(), + secret_remaining_use_count=PilotSecrets.secret_remaining_use_count - 1, + ) + .where(PilotSecrets.secret_uuid == secret_uuid) + ) + + # Execute the update using the connection + res = await self.conn.execute(stmt) + + if res.rowcount == 0: + raise SecretNotFoundError("Unknown secret") + + async def update_pilot_secrets_constraints( + self, hashed_secrets_to_pilot_stamps_mapping: list[dict[str, Any]] + ): + """Bulk associate pilots with secrets by updating theirs constraints. + + Important: We have to provide the updated constraints. + + Raises: + - PilotNotFoundError if one of the pilot does not exist + - NotImplementedError if at least of the pilot + + """ + # Better to give as a parameter pilot to secret associations, rather than associating here. + + stmt = ( + update(PilotSecrets) + .where(PilotSecrets.hashed_secret == bindparam("PilotHashedSecret")) + .values({"SecretConstraints": bindparam("PilotSecretConstraints")}) + ) + + try: + await self.conn.execute(stmt, hashed_secrets_to_pilot_stamps_mapping) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise SecretNotFoundError( + detail="at least one of these secrets does not exist", + ) from e + raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e + + async def set_secret_expirations( + self, secret_uuids: list[str], pilot_secret_expiration_dates: list[DateTime] + ): + """Bulk set expiration dates to secrets. + + Raises: + - SecretNotFoundError if one of the secret_uuid is not associated with a secret. + - NotImplementedError if a integrity error is not caught. + - + + """ + values = [ + {"b_SecretUUID": secret_uuid, "SecretExpirationDate": pilot_secret} + for secret_uuid, pilot_secret in zip( + secret_uuids, pilot_secret_expiration_dates + ) + ] + + # Prepare the update statement + stmt = ( + update(PilotSecrets) + .where(PilotSecrets.secret_uuid == bindparam("b_SecretUUID")) + .values({"SecretExpirationDate": bindparam("SecretExpirationDate")}) + ) + + try: + await self.conn.execute(stmt, values) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise SecretNotFoundError( + detail="at least one of these secrets does not exist", + ) from e + raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e + + async def search_secrets( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for secrets in the database.""" + return await self._search( + table=PilotSecrets, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 95a17f49c..f7e6db69a 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -3,9 +3,13 @@ from enum import Enum, auto from sqlalchemy import ( + BINARY, JSON, + DateTime, Index, + SmallInteger, String, + UniqueConstraint, Uuid, ) from sqlalchemy.orm import declarative_base @@ -99,3 +103,28 @@ class RefreshTokens(Base): sub = Column("Sub", String(256), index=True) __table_args__ = (Index("index_status_sub", status, sub),) + + +class PilotSecrets(Base): + __tablename__ = "PilotSecrets" + + secret_uuid = Column("SecretUUID", Uuid(as_uuid=False), primary_key=True) + + hashed_secret = Column("HashedSecret", BINARY(32)) + # Global count + # Null: Infinite use + secret_remaining_use_count = NullColumn( + "SecretRemainingUseCount", SmallInteger, default=1 + ) + secret_expiration_date = NullColumn("SecretExpirationDate", DateTime(timezone=True)) + # To authorize only specific pilots to access a secret + # The constraint format follows diracx.code.models.PilotSecretConstraints + secret_constraints = NullColumn("SecretConstraints", JSON) + + # If a date is set, then it used a secret (acts also like a "PilotUsedSecret" field) + pilot_secret_use_date = NullColumn("PilotSecretUseDate", DateTime(timezone=True)) + + __table_args__ = ( + UniqueConstraint("HashedSecret", name="uq_hashed_secret"), + Index("HashedSecret", "HashedSecret"), + ) diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index cb6a3fbbe..0bfb32e07 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -3,24 +3,20 @@ from datetime import datetime, timezone from typing import Any -from sqlalchemy import DateTime, bindparam +from sqlalchemy import bindparam from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import delete, insert, update -from uuid_utils import uuid7 from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, PilotNotFoundError, - SecretNotFoundError, ) from diracx.core.models import ( PilotFieldsMapping, - PilotSecretConstraints, PilotStatus, SearchSpec, SortSpec, ) -from diracx.db.sql.utils.functions import utcnow from ..utils import ( BaseSQLDB, @@ -30,7 +26,6 @@ PilotAgents, PilotAgentsDBBase, PilotOutput, - PilotSecrets, ) @@ -123,31 +118,6 @@ async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): "Engine Specific error not caught" + str(e) ) from e - async def insert_unique_secrets( - self, - hashed_secrets: list[bytes], - secret_global_use_count_max: int | None = 1, - secret_constraints: dict[bytes, PilotSecretConstraints] = {}, - ): - """Bulk insert secrets. - - Raises: - - NotImplementedError if we have an IntegrityError not caught - - """ - values = [ - { - "SecretUUID": str(uuid7()), - "SecretRemainingUseCount": secret_global_use_count_max, - "HashedSecret": hashed_secret, - "SecretConstraints": secret_constraints.get(hashed_secret, {}), - } - for hashed_secret in hashed_secrets - ] - - stmt = insert(PilotSecrets).values(values) - await self.conn.execute(stmt) - # ----------------------------- Delete Functions ----------------------------- async def delete_pilots(self, pilot_ids: list[int]): @@ -170,23 +140,6 @@ async def delete_pilot_logs(self, pilot_ids: list[int]): await self.conn.execute(stmt) - async def delete_secrets(self, secret_uuids: list[str]): - """Bulk delete secrets. - - Raises SecretNotFoundError if one of the secret was not found. - """ - stmt = delete(PilotSecrets).where(PilotSecrets.secret_uuid.in_(secret_uuids)) - - res = await self.conn.execute(stmt) - - if res.rowcount != len(secret_uuids): - raise SecretNotFoundError( - "At least one of the secret has not been deleted." - ) - - # We NEED to commit here, because we will raise an error after this function - await self.conn.commit() - # ----------------------------- Update Functions ----------------------------- async def update_pilot_fields( @@ -241,91 +194,6 @@ async def update_pilot_fields( if res.rowcount != len(pilot_stamps_to_fields_mapping): raise PilotNotFoundError("at least one of the given pilot does not exist.") - async def update_pilot_secret_use_time(self, secret_uuid: str) -> None: - """Updates when a pilot uses a secret. - - Raises PilotNotFoundError if the pilot does not exist - - """ - # Prepare the update statement - stmt = ( - update(PilotSecrets) - .values( - pilot_secret_use_date=utcnow(), - secret_remaining_use_count=PilotSecrets.secret_remaining_use_count - 1, - ) - .where(PilotSecrets.secret_uuid == secret_uuid) - ) - - # Execute the update using the connection - res = await self.conn.execute(stmt) - - if res.rowcount == 0: - raise SecretNotFoundError("Unknown secret") - - async def update_pilot_secrets_constraints( - self, hashed_secrets_to_pilot_stamps_mapping: list[dict[str, Any]] - ): - """Bulk associate pilots with secrets by updating theirs constraints. - - Important: We have to provide the updated constraints. - - Raises: - - PilotNotFoundError if one of the pilot does not exist - - NotImplementedError if at least of the pilot - - """ - # Better to give as a parameter pilot to secret associations, rather than associating here. - - stmt = ( - update(PilotSecrets) - .where(PilotSecrets.hashed_secret == bindparam("PilotHashedSecret")) - .values({"SecretConstraints": bindparam("PilotSecretConstraints")}) - ) - - try: - await self.conn.execute(stmt, hashed_secrets_to_pilot_stamps_mapping) - except IntegrityError as e: - if "foreign key" in str(e.orig).lower(): - raise SecretNotFoundError( - detail="at least one of these secrets does not exist", - ) from e - raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e - - async def set_secret_expirations( - self, secret_uuids: list[str], pilot_secret_expiration_dates: list[DateTime] - ): - """Bulk set expiration dates to secrets. - - Raises: - - SecretNotFoundError if one of the secret_uuid is not associated with a secret. - - NotImplementedError if a integrity error is not caught. - - - - """ - values = [ - {"b_SecretUUID": secret_uuid, "SecretExpirationDate": pilot_secret} - for secret_uuid, pilot_secret in zip( - secret_uuids, pilot_secret_expiration_dates - ) - ] - - # Prepare the update statement - stmt = ( - update(PilotSecrets) - .where(PilotSecrets.secret_uuid == bindparam("b_SecretUUID")) - .values({"SecretExpirationDate": bindparam("SecretExpirationDate")}) - ) - - try: - await self.conn.execute(stmt, values) - except IntegrityError as e: - if "foreign key" in str(e.orig).lower(): - raise SecretNotFoundError( - detail="at least one of these secrets does not exist", - ) from e - raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e - # ----------------------------- Search Functions ----------------------------- async def search_pilots( @@ -370,27 +238,6 @@ async def search_pilot_to_job_mapping( page=page, ) - async def search_secrets( - self, - parameters: list[str] | None, - search: list[SearchSpec], - sorts: list[SortSpec], - *, - distinct: bool = False, - per_page: int = 100, - page: int | None = None, - ) -> tuple[int, list[dict[Any, Any]]]: - """Search for secrets in the database.""" - return await self._search( - table=PilotSecrets, - parameters=parameters, - search=search, - sorts=sorts, - distinct=distinct, - per_page=per_page, - page=page, - ) - async def pilot_summary( self, group_by: list[str], search: list[SearchSpec] ) -> list[dict[str, str | int]]: diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index 4ad4c9cb3..af087f1f8 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -1,17 +1,12 @@ from __future__ import annotations from sqlalchemy import ( - BINARY, - JSON, DateTime, Double, Index, Integer, - SmallInteger, String, Text, - UniqueConstraint, - Uuid, ) from sqlalchemy.orm import declarative_base @@ -66,28 +61,3 @@ class PilotOutput(PilotAgentsDBBase): pilot_id = Column("PilotID", Integer, primary_key=True) std_output = Column("StdOutput", Text) std_error = Column("StdError", Text) - - -class PilotSecrets(PilotAgentsDBBase): - __tablename__ = "PilotSecrets" - - secret_uuid = Column("SecretUUID", Uuid(as_uuid=False), primary_key=True) - - hashed_secret = Column("HashedSecret", BINARY(32)) - # Global count - # Null: Infinite use - secret_remaining_use_count = NullColumn( - "SecretRemainingUseCount", SmallInteger, default=1 - ) - secret_expiration_date = NullColumn("SecretExpirationDate", DateTime(timezone=True)) - # To authorize only specific pilots to access a secret - # The constraint format follows diracx.code.models.PilotSecretConstraints - secret_constraints = NullColumn("SecretConstraints", JSON) - - # If a date is set, then it used a secret (acts also like a "PilotUsedSecret" field) - pilot_secret_use_date = NullColumn("PilotSecretUseDate", DateTime(timezone=True)) - - __table_args__ = ( - UniqueConstraint("HashedSecret", name="uq_hashed_secret"), - Index("HashedSecret", "HashedSecret"), - ) diff --git a/diracx-db/tests/auth/test_authorization_flow.py b/diracx-db/tests/auth/test_authorization_flow.py index 56a3d332c..0336760a6 100644 --- a/diracx-db/tests/auth/test_authorization_flow.py +++ b/diracx-db/tests/auth/test_authorization_flow.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import NoResultFound from diracx.core.exceptions import AuthorizationError -from diracx.db.sql.auth.db import AuthDB +from diracx.db.sql import AuthDB MAX_VALIDITY = 2 EXPIRED = 0 diff --git a/diracx-db/tests/auth/test_device_flow.py b/diracx-db/tests/auth/test_device_flow.py index 112e77898..ac6cd0825 100644 --- a/diracx-db/tests/auth/test_device_flow.py +++ b/diracx-db/tests/auth/test_device_flow.py @@ -7,7 +7,7 @@ from sqlalchemy.exc import NoResultFound from diracx.core.exceptions import AuthorizationError -from diracx.db.sql.auth.db import AuthDB +from diracx.db.sql import AuthDB from diracx.db.sql.auth.schema import USER_CODE_LENGTH from diracx.db.sql.utils.functions import substract_date diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 88d8c7017..0fd8b3d21 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -3,7 +3,7 @@ import pytest from uuid_utils import UUID, uuid7 -from diracx.db.sql.auth.db import AuthDB +from diracx.db.sql import AuthDB from diracx.db.sql.auth.schema import RefreshTokenStatus diff --git a/diracx-db/tests/pilots/test_pilot_auth.py b/diracx-db/tests/pilots/test_pilot_auth.py deleted file mode 100644 index 99ea0c58e..000000000 --- a/diracx-db/tests/pilots/test_pilot_auth.py +++ /dev/null @@ -1,166 +0,0 @@ -from __future__ import annotations - -from datetime import timedelta -from random import shuffle -from typing import AsyncGenerator, Generator - -import freezegun -import pytest -import sqlalchemy - -from diracx.core.exceptions import ( - BadPilotCredentialsError, - PilotNotFoundError, - SecretHasExpiredError, - SecretNotFoundError, -) -from diracx.db.sql.pilots.db import PilotAgentsDB -from diracx.db.sql.utils.functions import raw_hash -from diracx.testing.time import mock_sqlite_time - -from .utils import ( - add_secrets_and_time, # noqa: F401 - add_stamps, # noqa: F401 - verify_pilot_secret, -) - - -@pytest.fixture -async def pilot_db() -> AsyncGenerator[PilotAgentsDB, None]: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with agents_db.engine_context(): - sqlalchemy.event.listen( - agents_db.engine.sync_engine, "connect", mock_sqlite_time - ) - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -@pytest.fixture() -def frozen_time() -> Generator[freezegun.FreezeGun, None]: - with freezegun.freeze_time("2012-01-14") as ft: - yield ft - - -@pytest.mark.parametrize("secret_duration_sec", [10]) -@pytest.mark.asyncio -async def test_create_pilot_and_verify_secret( - pilot_db: PilotAgentsDB, - add_secrets_and_time, # noqa: F811 - frozen_time: freezegun.FreezeGun, -): - # Add pilots - result = add_secrets_and_time - stamps = result["stamps"] - secrets = result["secrets"] - - pairs = list(zip(stamps, secrets)) - # Shuffle it to prove that credentials are well associated - shuffle(pairs) - - async with pilot_db as pilot_db: - for stamp, secret in pairs: - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamp, - hashed_secret=raw_hash(secret), - frozen_time=frozen_time, - ) - - with pytest.raises(SecretNotFoundError): - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamps[0], - hashed_secret=raw_hash("I love stawberries :)"), - frozen_time=frozen_time, - ) - - with pytest.raises(PilotNotFoundError): - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp="I am a spider", - hashed_secret=raw_hash(secrets[0]), - frozen_time=frozen_time, - ) - - -@pytest.mark.parametrize("secret_duration_sec", [1]) -@pytest.mark.asyncio -async def test_create_pilot_and_verify_secret_with_delay( - pilot_db: PilotAgentsDB, - add_secrets_and_time, # noqa: F811 - frozen_time: freezegun.FreezeGun, -): - # Add pilots - result = add_secrets_and_time - stamps = result["stamps"] - secrets = result["secrets"] - - # Move forward few minutes - frozen_time.tick(delta=timedelta(minutes=5)) - - async with pilot_db as pilot_db: - with pytest.raises(SecretHasExpiredError): - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamps[0], - hashed_secret=raw_hash(secrets[0]), - frozen_time=frozen_time, - ) - - -@pytest.mark.parametrize("secret_duration_sec", [10]) -@pytest.mark.asyncio -async def test_create_pilot_and_verify_secret_too_much_secret_use( - pilot_db: PilotAgentsDB, - add_secrets_and_time, # noqa: F811 - frozen_time: freezegun.FreezeGun, -): - # Add pilots - result = add_secrets_and_time - stamps = result["stamps"] - secrets = result["secrets"] - - # First login, should work - async with pilot_db as pilot_db: - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamps[0], - hashed_secret=raw_hash(secrets[0]), - frozen_time=frozen_time, - ) - - # Second login, should not work because maxed out at 1 try - # If the foreign key works, we should have "SecretNotFoundError" - with pytest.raises(SecretNotFoundError): - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamps[0], - hashed_secret=raw_hash(secrets[0]), - frozen_time=frozen_time, - ) - - -@pytest.mark.parametrize("secret_duration_sec", [10]) -@pytest.mark.asyncio -async def test_create_pilot_and_login_with_bad_secret( - pilot_db: PilotAgentsDB, - add_secrets_and_time, # noqa: F811 - frozen_time: freezegun.FreezeGun, -): - # Add pilots - result = add_secrets_and_time - stamps = result["stamps"] - secrets = result["secrets"] - - async with pilot_db as pilot_db: - # Pilot1 will try to login with every other pilots's secret - for secret in secrets[1:]: - with pytest.raises(BadPilotCredentialsError): - await verify_pilot_secret( - pilot_db=pilot_db, - pilot_stamp=stamps[0], - hashed_secret=raw_hash(secret), - frozen_time=frozen_time, - ) diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index 1e7397b39..41198f389 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -11,7 +11,7 @@ PilotFieldsMapping, PilotStatus, ) -from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql import PilotAgentsDB from .utils import ( add_stamps, # noqa: F401 diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index be80f0179..605264b15 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -13,7 +13,7 @@ VectorSearchOperator, VectorSearchSpec, ) -from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql import PilotAgentsDB MAIN_VO = "lhcb" N = 100 diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py index df73fd3ec..07b544c7c 100644 --- a/diracx-db/tests/pilots/utils.py +++ b/diracx-db/tests/pilots/utils.py @@ -1,28 +1,19 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from typing import Any -import freezegun import pytest from sqlalchemy import update -from diracx.core.exceptions import ( - BadPilotCredentialsError, - PilotNotFoundError, - SecretHasExpiredError, - SecretNotFoundError, -) from diracx.core.models import ( - PilotSecretConstraints, ScalarSearchOperator, ScalarSearchSpec, VectorSearchOperator, VectorSearchSpec, ) -from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql import PilotAgentsDB from diracx.db.sql.pilots.schema import PilotAgents -from diracx.db.sql.utils.functions import raw_hash MAIN_VO = "lhcb" N = 100 @@ -72,56 +63,6 @@ async def get_pilot_jobs_ids_by_pilot_id( return [job["JobID"] for job in jobs] -async def get_secrets_by_hashed_secrets( - pilot_db: PilotAgentsDB, hashed_secrets: list[bytes], parameters: list[str] = [] -) -> list[dict[Any, Any]]: - _, secrets = await pilot_db.search_secrets( - parameters=parameters, - search=[ - VectorSearchSpec( - parameter="HashedSecret", - operator=VectorSearchOperator.IN, - values=hashed_secrets, - ) - ], - sorts=[], - distinct=True, - per_page=1000, - ) - - return secrets - - -async def get_secrets_by_uuid( - pilot_db: PilotAgentsDB, secret_uuids: list[str], parameters: list[str] = [] -) -> list[dict[Any, Any]]: - parameters.append("SecretUUID") # To avoid bug later on `found_keys = ...` - - _, secrets = await pilot_db.search_secrets( - parameters=parameters, - search=[ - VectorSearchSpec( - parameter="SecretUUID", - operator=VectorSearchOperator.IN, - values=secret_uuids, - ) - ], - sorts=[], - distinct=True, - per_page=1000, - ) - - # Custom handling, to see which secret_uuid does not exist - # TODO: Add missing in the error - found_keys = {row["SecretUUID"] for row in secrets} - missing = set(secret_uuids) - found_keys - - if missing: - raise SecretNotFoundError(detail=str(missing)) - - return secrets - - # ------------ Creating data ------------ @@ -210,113 +151,3 @@ async def create_old_pilots_environment(pilot_db, create_timed_pilots): await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old - - -@pytest.fixture -async def add_secrets_and_time( - pilot_db, add_stamps, secret_duration_sec, frozen_time: freezegun.FreezeGun -): - # Retrieve the stamps from the add_stamps fixture - stamps = [pilot["PilotStamp"] for pilot in await add_stamps()] - - # Add a VO restriction as well as association with a specific pilot - secrets = [f"AW0nd3rfulS3cr3t_{str(i)}" for i in range(len(stamps))] - hashed_secrets = [raw_hash(secret) for secret in secrets] - constraints = { - hashed_secret: PilotSecretConstraints(VOs=[MAIN_VO], PilotStamps=[stamp]) - for hashed_secret, stamp in zip(hashed_secrets, stamps) - } - - async with pilot_db as pilot_db: - # Add creds - await pilot_db.insert_unique_secrets( - hashed_secrets=hashed_secrets, secret_constraints=constraints - ) - - # Associate with pilot - secrets_obj = await get_secrets_by_hashed_secrets(pilot_db, hashed_secrets) - - assert len(secrets_obj) == len(hashed_secrets) == len(stamps) - - # extract_timestamp_from_uuid7(secret_obj["SecretUUID"]) does not work here - # See #548 - expiration_date = [ - datetime.now(timezone.utc) + timedelta(seconds=secret_duration_sec) - for secret_obj in secrets_obj - ] - - await pilot_db.set_secret_expirations( - secret_uuids=[secret_obj["SecretUUID"] for secret_obj in secrets_obj], - pilot_secret_expiration_dates=expiration_date, - ) - - # Return both non-hashed secrets and stamps - return {"stamps": stamps, "secrets": secrets} - - -# ------------ Verifying data ------------ - - -async def verify_pilot_secret( - pilot_stamp: str, - pilot_db: PilotAgentsDB, - hashed_secret: bytes, - frozen_time: freezegun.FreezeGun, -) -> None: - # 1. Get the pilot - pilots = await get_pilots_by_stamp( - pilot_db=pilot_db, - pilot_stamps=[pilot_stamp], - parameters=["VO", "PilotStamp"], - ) - if len(pilots) == 0: - raise PilotNotFoundError() - pilot = dict(pilots[0]) - - # 2. Get the secret itself - secrets = await get_secrets_by_hashed_secrets( - pilot_db=pilot_db, hashed_secrets=[hashed_secret] - ) - if len(secrets) == 0: - raise SecretNotFoundError(str(hashed_secret)) - secret = secrets[0] - secret_uuid = secret["SecretUUID"] - secret_constraints = PilotSecretConstraints(**secret["SecretConstraints"]) - - # 3. Check the constraints - await check_pilot_constraints(pilot=pilot, secret_constraints=secret_constraints) - - # 4. Check if the secret is expired - now = datetime.now(tz=timezone.utc) - # Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454 - expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc) - if expiration < now: - await pilot_db.delete_secrets([secret_uuid]) - - raise SecretHasExpiredError( - f"expiration_date {secret['SecretExpirationDate']}", - ) - - # 5. Now the pilot is authorized, change when the pilot used the secret. - await pilot_db.update_pilot_secret_use_time( - secret_uuid=secret_uuid, - ) - - # 6. Delete the secret if its count attained the secret_global_use_count_max - if secret["SecretRemainingUseCount"]: - # If we use it another time, SecretRemainingUseCount will be equal to 0 so we can delete it - if secret["SecretRemainingUseCount"] == 1: - await pilot_db.delete_secrets([secret_uuid]) - - -async def check_pilot_constraints( - pilot: dict[str, Any], secret_constraints: PilotSecretConstraints -): - key_map = {"VOs": "VO", "PilotStamps": "PilotStamp", "Sites": "Site"} - - for constraint_key, pilot_key in key_map.items(): - allowed_values = secret_constraints.get(constraint_key) - if allowed_values: - pilot_value = pilot.get(pilot_key) - if pilot_value is None or pilot_value not in allowed_values: - raise BadPilotCredentialsError() diff --git a/diracx-logic/src/diracx/logic/pilots/auth.py b/diracx-logic/src/diracx/logic/pilots/auth.py index ec9ae2188..385d40c62 100644 --- a/diracx-logic/src/diracx/logic/pilots/auth.py +++ b/diracx-logic/src/diracx/logic/pilots/auth.py @@ -37,7 +37,7 @@ async def create_raw_secrets( n: int, - pilot_db: PilotAgentsDB, + auth_db: AuthDB, settings: AuthSettings, secret_constraint: PilotSecretConstraints, pilot_secret_use_count_max: int | None = 1, @@ -54,14 +54,14 @@ async def create_raw_secrets( } # Insert secrets - await pilot_db.insert_unique_secrets( + await auth_db.insert_unique_secrets( hashed_secrets=hashed_secrets, secret_global_use_count_max=pilot_secret_use_count_max, secret_constraints=secret_constraints, ) secrets_added = await get_secrets_by_hashed_secrets( - pilot_db=pilot_db, + auth_db=auth_db, hashed_secrets=hashed_secrets, parameters=["SecretUUID"], # For efficiency ) @@ -81,7 +81,7 @@ async def create_raw_secrets( secret_uuids = [secret["SecretUUID"] for secret in secrets_added] # Helps compatibility between sql engines - await pilot_db.set_secret_expirations( + await auth_db.set_secret_expirations( secret_uuids=secret_uuids, pilot_secret_expiration_dates=expiration_dates, # type: ignore ) @@ -95,7 +95,7 @@ async def create_raw_secrets( async def create_secrets( n: int, - pilot_db: PilotAgentsDB, + auth_db: AuthDB, settings: AuthSettings, secret_constraint: PilotSecretConstraints, pilot_secret_use_count_max: int | None = 1, @@ -103,7 +103,7 @@ async def create_secrets( ) -> list[PilotSecretsInfo]: pilot_secrets, expiration_dates_timestamps = await create_raw_secrets( n=n, - pilot_db=pilot_db, + auth_db=auth_db, settings=settings, pilot_secret_use_count_max=pilot_secret_use_count_max, expiration_minutes=expiration_minutes, @@ -120,7 +120,7 @@ async def create_secrets( async def update_secrets_constraints( - pilot_db: PilotAgentsDB, + auth_db: AuthDB, secrets_to_constraints_dict: dict[str, PilotSecretConstraints], ): # 1. Create a mapping that uses hashed_secret @@ -137,7 +137,7 @@ async def update_secrets_constraints( # 2. Get the secret ids to later associate them with pilots # It also verifies that all secrets exist secrets_obj = await get_secrets_by_hashed_secrets( - pilot_db=pilot_db, + auth_db=auth_db, hashed_secrets=list(hashed_secrets_to_pilot_stamps_dict.keys()), parameters=["SecretConstraints"], # For efficiency, we don't need more info ) @@ -164,7 +164,7 @@ async def update_secrets_constraints( } ) - await pilot_db.update_pilot_secrets_constraints( + await auth_db.update_pilot_secrets_constraints( hashed_secrets_to_pilot_stamps_mapping ) @@ -201,7 +201,7 @@ async def verify_pilot_credentials( # 2. Get the secret itself secrets = await get_secrets_by_hashed_secrets( - pilot_db=pilot_db, hashed_secrets=[hashed_secret] + auth_db=auth_db, hashed_secrets=[hashed_secret] ) secret = secrets[0] secret_uuid = secret["SecretUUID"] @@ -215,14 +215,14 @@ async def verify_pilot_credentials( # Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454 expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc) if expiration < now: - await pilot_db.delete_secrets([secret_uuid]) + await auth_db.delete_secrets([secret_uuid]) raise SecretHasExpiredError( detail=f"expiration_date{secret['SecretExpirationDate']}", ) # 5. Now the pilot is authorized, change when the pilot used the secret. - await pilot_db.update_pilot_secret_use_time( + await auth_db.update_pilot_secret_use_time( secret_uuid=secret_uuid, ) @@ -230,7 +230,7 @@ async def verify_pilot_credentials( if secret["SecretRemainingUseCount"]: # If we use it another time, SecretRemainingUseCount will be equal to 0 so we can delete it if secret["SecretRemainingUseCount"] == 1: - await pilot_db.delete_secrets([secret_uuid]) + await auth_db.delete_secrets([secret_uuid]) # Get token, and serialize access_token_payload, refresh_token_payload = await generate_pilot_tokens( diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index a74012b74..ab518526e 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -9,7 +9,7 @@ PilotSecretConstraints, ) from diracx.core.settings import AuthSettings -from diracx.db.sql import PilotAgentsDB +from diracx.db.sql import AuthDB, PilotAgentsDB from .auth import create_raw_secrets, update_secrets_constraints from .query import ( @@ -22,6 +22,7 @@ async def register_new_pilots( pilot_db: PilotAgentsDB, + auth_db: AuthDB, pilot_stamps: list[str], vo: str, grid_type: str, @@ -62,7 +63,7 @@ async def register_new_pilots( pilot_secrets, expiration_dates_timestamps = await create_raw_secrets( n=len(pilot_stamps), - pilot_db=pilot_db, + auth_db=auth_db, settings=settings, pilot_secret_use_count_max=pilot_secret_use_count_max, secret_constraint=PilotSecretConstraints(VOs=[vo]), @@ -74,7 +75,7 @@ async def register_new_pilots( } await update_secrets_constraints( - pilot_db=pilot_db, secrets_to_constraints_dict=constraints + auth_db=auth_db, secrets_to_constraints_dict=constraints ) return [ diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 7be9bd802..43afd8653 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -14,7 +14,7 @@ VectorSearchOperator, VectorSearchSpec, ) -from diracx.db.sql import PilotAgentsDB +from diracx.db.sql import AuthDB, PilotAgentsDB MAX_PER_PAGE = 10000 @@ -192,12 +192,12 @@ async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): async def get_secrets_by_hashed_secrets( - pilot_db: PilotAgentsDB, hashed_secrets: list[bytes], parameters: list[str] = [] + auth_db: AuthDB, hashed_secrets: list[bytes], parameters: list[str] = [] ) -> list[dict[Any, Any]]: if parameters: parameters.append("HashedSecret") - _, secrets = await pilot_db.search_secrets( + _, secrets = await auth_db.search_secrets( parameters=parameters, search=[ VectorSearchSpec( @@ -222,12 +222,12 @@ async def get_secrets_by_hashed_secrets( async def get_secrets_by_uuid( - pilot_db: PilotAgentsDB, secret_uuids: list[str], parameters: list[str] = [] + auth_db: AuthDB, secret_uuids: list[str], parameters: list[str] = [] ) -> list[dict[Any, Any]]: if parameters: parameters.append("SecretUUID") # To avoid bug later on `found_keys = ...` - _, secrets = await pilot_db.search_secrets( + _, secrets = await auth_db.search_secrets( parameters=parameters, search=[ VectorSearchSpec( diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 61a324f79..15b14673f 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -8,8 +8,8 @@ from diracx.core.models import VectorSearchOperator, VectorSearchSpec from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR +from diracx.db.sql import PilotAgentsDB from diracx.db.sql.job.db import JobDB -from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.logic.pilots.query import get_pilots_by_stamp from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 2304f166b..d02f1f918 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -33,7 +33,7 @@ from diracx.logic.pilots.query import get_pilot_ids_by_job_id from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token -from ..dependencies import AuthSettings, JobDB, PilotAgentsDB +from ..dependencies import AuthDB, AuthSettings, JobDB, PilotAgentsDB from ..fastapi_classes import DiracxRouter from .access_policies import ( ActionType, @@ -46,6 +46,7 @@ @router.post("/") async def add_pilot_stamps( pilot_db: PilotAgentsDB, + auth_db: AuthDB, pilot_stamps: Annotated[ list[str], Body(description="List of the pilot stamps we want to add to the db."), @@ -103,6 +104,7 @@ async def add_pilot_stamps( try: return await register_new_pilots( + auth_db=auth_db, pilot_db=pilot_db, pilot_stamps=pilot_stamps, vo=user_info.vo, @@ -196,7 +198,7 @@ async def create_pilot_secrets( ], vo: Annotated[str, Body(description="Only VO that can access a secret.")], check_permissions: CheckPilotManagementPolicyCallable, - pilot_db: PilotAgentsDB, + auth_db: AuthDB, settings: AuthSettings, ) -> list[PilotSecretsInfo]: """Endpoint to create secrets.""" @@ -215,7 +217,7 @@ async def create_pilot_secrets( return await create_secrets( n=n, - pilot_db=pilot_db, + auth_db=auth_db, settings=settings, secret_constraint=PilotSecretConstraints(VOs=[vo]), pilot_secret_use_count_max=pilot_secret_use_count_max, @@ -229,7 +231,7 @@ async def update_secrets_constraints( dict[str, PilotSecretConstraints], Body(description="Mapping between secrets and pilots.", embed=False), ], - pilot_db: PilotAgentsDB, + auth_db: AuthDB, check_permissions: CheckPilotManagementPolicyCallable, ): """Endpoint to associate pilots with secrets.""" @@ -244,7 +246,7 @@ async def update_secrets_constraints( try: await update_secrets_constraints_bl( - pilot_db=pilot_db, + auth_db=auth_db, secrets_to_constraints_dict=secrets_to_constraints_dict, ) except SecretNotFoundError as e: diff --git a/diracx-routers/tests/pilots/test_pilot_auth.py b/diracx-routers/tests/pilots/test_pilot_auth.py index 618788b18..11c74b999 100644 --- a/diracx-routers/tests/pilots/test_pilot_auth.py +++ b/diracx-routers/tests/pilots/test_pilot_auth.py @@ -8,7 +8,7 @@ from pytest_httpx import HTTPXMock from diracx.core.models import PilotSecretConstraints -from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql import AuthDB, PilotAgentsDB from diracx.db.sql.utils.functions import raw_hash from diracx.logic.pilots.query import ( get_pilots_by_stamp, @@ -70,14 +70,14 @@ async def add_stamps(normal_test_client): pilots = await get_pilots_by_stamp(db, stamps) - return pilots + return pilots @pytest.fixture async def add_secrets_and_time(normal_test_client, add_stamps, secret_duration_sec): - db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + db = normal_test_client.app.dependency_overrides[AuthDB.transaction].args[0] - async with db as pilot_db: + async with db as auth_db: # Retrieve the stamps from the add_stamps fixture stamps = [pilot["PilotStamp"] for pilot in add_stamps] @@ -90,7 +90,7 @@ async def add_secrets_and_time(normal_test_client, add_stamps, secret_duration_s } # Add creds - await pilot_db.insert_unique_secrets( + await auth_db.insert_unique_secrets( hashed_secrets=hashed_secrets, secret_constraints=constraints ) @@ -106,7 +106,7 @@ async def add_secrets_and_time(normal_test_client, add_stamps, secret_duration_s for secret_obj in secrets_obj ] - await pilot_db.set_secret_expirations( + await auth_db.set_secret_expirations( secret_uuids=[secret_obj["SecretUUID"] for secret_obj in secrets_obj], pilot_secret_expiration_dates=expiration_date, ) diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index c6d5cedb4..e189415ba 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -23,6 +23,7 @@ "ConfigSource", "DevelopmentSettings", "PilotAgentsDB", + "AuthDB", "PilotManagementAccessPolicy", ] ) From b0cbdbcd2e2c2594762a373107638038a878c810 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Wed, 6 Aug 2025 12:50:05 +0200 Subject: [PATCH 08/11] docs: Add doc to recursive merge, and uuid_utils to pyproject --- diracx-core/pyproject.toml | 1 + diracx-core/src/diracx/core/utils.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/diracx-core/pyproject.toml b/diracx-core/pyproject.toml index f5e8f5868..39dda028e 100644 --- a/diracx-core/pyproject.toml +++ b/diracx-core/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "pydantic-settings", "pyyaml", "sh", + "uuid_utils" ] dynamic = ["version"] diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index 5c92fe7f4..05a171748 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -287,15 +287,26 @@ def extract_timestamp_from_uuid7(uuid_str: str) -> datetime: def recursive_dict_merge(x: T_DICTS, y: T_DICTS) -> T_DICTS: + """Used to merge two dictionaries with different types in it. + + Use case: pilot secrets constraints that we update. + """ result: dict[str, Any] = dict(x) for k, v in y.items(): if k in result: if isinstance(result[k], dict) and isinstance(v, dict): + # If it's a dict, recursion result[k] = recursive_dict_merge(result[k], v) elif isinstance(result[k], list) and isinstance(v, list): - result[k] = result[k] + v + # Prevent duplicates (costy operation, but done no that many times) + result[k] = list(set(result[k] + v)) + elif isinstance(result[k], set) and isinstance(v, set): + # If it's a set, update values + result[k].update(v) else: + # Other types are: int, str, byte, float, bool + # No need to handle it differently, just replace the value result[k] = v else: result[k] = v From b15c91d94f9ab9f6498afc97ca04bd2503fee2f4 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Wed, 6 Aug 2025 14:05:44 +0200 Subject: [PATCH 09/11] feat: Add a 'create secrets' command in the CLI --- diracx-cli/pyproject.toml | 1 + diracx-cli/src/diracx/cli/pilots.py | 66 +++++++++++++++++++ .../src/diracx/client/patches/pilots/aio.py | 11 +++- .../diracx/client/patches/pilots/common.py | 30 +++++++++ .../src/diracx/client/patches/pilots/sync.py | 11 +++- 5 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 diracx-cli/src/diracx/cli/pilots.py diff --git a/diracx-cli/pyproject.toml b/diracx-cli/pyproject.toml index f3ab83f94..8e2d3d505 100644 --- a/diracx-cli/pyproject.toml +++ b/diracx-cli/pyproject.toml @@ -39,6 +39,7 @@ dirac = "diracx.cli:app" [project.entry-points."diracx.cli"] jobs = "diracx.cli.jobs:app" config = "diracx.cli.config:app" +pilots = "diracx.cli.pilots:app" [project.entry-points."diracx.cli.hidden"] internal = "diracx.cli.internal:app" diff --git a/diracx-cli/src/diracx/cli/pilots.py b/diracx-cli/src/diracx/cli/pilots.py new file mode 100644 index 000000000..25dc6261d --- /dev/null +++ b/diracx-cli/src/diracx/cli/pilots.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +__all__ = ("app",) + +import asyncio +import json +from typing import Annotated, Optional + +import typer + +from diracx.client.aio import AsyncDiracClient + +from .utils import AsyncTyper + +app = AsyncTyper() + + +async def installation_metadata(): + async with AsyncDiracClient() as api: + return await api.well_known.get_installation_metadata() + + +def vo_callback(vo: str | None) -> str: + metadata = asyncio.run(installation_metadata()) + vos = list(metadata.virtual_organizations) + if not vo: + raise typer.BadParameter( + f"VO must be specified, available options are: {' '.join(vos)}" + ) + if vo not in vos: + raise typer.BadParameter( + f"Unknown VO {vo}, available options are: {' '.join(vos)}" + ) + return vo + + +@app.async_command() +async def generate_pilot_secrets( + vo: Annotated[ + str, + typer.Argument(callback=vo_callback, help="Virtual Organization name"), + ], + n: Annotated[ + int, + typer.Argument(help="Number of secrets to generate."), + ], + expiration_minutes: Optional[int] = typer.Option( + 60, + help="Expiration in minutes of the secrets.", + ), + max_use: Optional[int] = typer.Option( + 60, + help="Number of uses max for a secret.", + ), +): + async with AsyncDiracClient() as api: + secrets = await api.pilots.create_pilot_secrets( + n=n, + expiration_minutes=expiration_minutes, + pilot_secret_use_count_max=max_use, + vo=vo, + ) + # Convert each model to dict + secrets_dict = [secret.as_dict() for secret in secrets] + + print(json.dumps(secrets_dict, indent=2)) diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py index 56d278a1f..1e2c21fb6 100644 --- a/diracx-client/src/diracx/client/patches/pilots/aio.py +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -16,16 +16,18 @@ from azure.core.tracing.decorator_async import distributed_trace_async from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations -from ..._generated.models._models import PilotCredentialsInfo +from ..._generated.models._models import PilotCredentialsInfo, PilotSecretsInfo from .common import ( make_search_body, make_summary_body, make_add_pilot_stamps_body, make_update_pilot_fields_body, + make_create_pilot_secrets_body, SearchKwargs, SummaryKwargs, AddPilotStampsKwargs, - UpdatePilotFieldsKwargs + UpdatePilotFieldsKwargs, + CreatePilotSecretsKwargs, ) # We're intentionally ignoring overrides here because we want to change the interface. @@ -52,3 +54,8 @@ async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: """TODO""" return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) + + @distributed_trace_async + async def create_pilot_secrets(self, **kwargs: Unpack[CreatePilotSecretsKwargs]) -> list[PilotSecretsInfo]: + """TODO""" + return await super().create_pilot_secrets(**make_create_pilot_secrets_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py index 258bc42f8..a7476317a 100644 --- a/diracx-client/src/diracx/client/patches/pilots/common.py +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -11,6 +11,8 @@ "make_add_pilot_stamps_body", "UpdatePilotFieldsKwargs", "make_update_pilot_fields_body" + "CreatePilotSecretsKwargs", + "make_create_pilot_secrets_body" ] import json @@ -146,3 +148,31 @@ def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} result.update(cast(ResponseExtra, kwargs)) return result + +# ------------------ CreatePilotSecrets ------------------ + +class CreatePilotSecretsBody(TypedDict, total=False): + n: int + expiration_minutes: int | None + pilot_secret_use_count_max: int | None + vo: str + +class CreatePilotSecretsKwargs(CreatePilotSecretsBody, ResponseExtra): ... + +class UnderlyingCreatePilotSecrets(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_create_pilot_secrets_body(**kwargs: Unpack[CreatePilotSecretsKwargs]) -> UnderlyingCreatePilotSecrets: + body: CreatePilotSecretsBody = {} + for key in CreatePilotSecretsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["n", "expiration_minutes", "pilot_secret_use_count_max", "vo"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingCreatePilotSecrets = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py index e3059013b..31ab15df5 100644 --- a/diracx-client/src/diracx/client/patches/pilots/sync.py +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -16,16 +16,18 @@ from azure.core.tracing.decorator import distributed_trace from ..._generated.operations._operations import PilotsOperations as _PilotsOperations -from ..._generated.models._models import PilotCredentialsInfo +from ..._generated.models._models import PilotCredentialsInfo, PilotSecretsInfo from .common import ( make_search_body, make_summary_body, make_add_pilot_stamps_body, make_update_pilot_fields_body, + make_create_pilot_secrets_body, SearchKwargs, SummaryKwargs, AddPilotStampsKwargs, - UpdatePilotFieldsKwargs + UpdatePilotFieldsKwargs, + CreatePilotSecretsKwargs, ) # We're intentionally ignoring overrides here because we want to change the interface. @@ -52,3 +54,8 @@ def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list[Pilot def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: """TODO""" return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) + + @distributed_trace + def create_pilot_secrets(self, **kwargs: Unpack[CreatePilotSecretsKwargs]) -> list[PilotSecretsInfo]: + """TODO""" + return super().create_pilot_secrets(**make_create_pilot_secrets_body(**kwargs)) From 9298ee64c5ea336a5664c2c9f659f0239372e3ce Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Thu, 7 Aug 2025 10:05:36 +0200 Subject: [PATCH 10/11] test: Test CI with pilot secrets --- .github/workflows/integration.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 28e3d8a05..800438531 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: dirac-branch: - - robin-migrate-client + - robin-pilot-registrations steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 From 54dae47c4dd4e404b61f512826045cbe9968e090 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Thu, 7 Aug 2025 11:22:35 +0200 Subject: [PATCH 11/11] fix: Generate autorest client --- .../client/_generated/models/__init__.py | 4 +- .../client/_generated/models/_models.py | 60 ++++++++++--------- .../client/_generated/models/__init__.py | 4 +- .../client/_generated/models/_models.py | 60 ++++++++++--------- 4 files changed, 66 insertions(+), 62 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index a37bc168a..4c58bddf6 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -14,9 +14,9 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyAuthRefreshPilotTokens, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, - BodyAuthRefreshPilotTokens, BodyPilotsAddPilotStamps, BodyPilotsCreatePilotSecrets, BodyPilotsUpdatePilotFields, @@ -78,9 +78,9 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyAuthRefreshPilotTokens", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", - "BodyAuthRefreshPilotTokens", "BodyPilotsAddPilotStamps", "BodyPilotsCreatePilotSecrets", "BodyPilotsUpdatePilotFields", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 0211ac35a..cd2c58667 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -94,34 +94,41 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" -class BodyJobsRescheduleJobs(_serialization.Model): - """Body_jobs_reschedule_jobs. +class BodyAuthRefreshPilotTokens(_serialization.Model): + """Body_auth_refresh_pilot_tokens. All required parameters must be populated in order to send to server. - :ivar job_ids: Job Ids. Required. - :vartype job_ids: list[int] + :ivar refresh_token: Refresh Token given at login by DiracX. Required. + :vartype refresh_token: str + :ivar pilot_stamp: Pilot stamp. Required. + :vartype pilot_stamp: str """ _validation = { - "job_ids": {"required": True}, + "refresh_token": {"required": True}, + "pilot_stamp": {"required": True}, } _attribute_map = { - "job_ids": {"key": "job_ids", "type": "[int]"}, + "refresh_token": {"key": "refresh_token", "type": "str"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, } - def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: + def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: """ - :keyword job_ids: Job Ids. Required. - :paramtype job_ids: list[int] + :keyword refresh_token: Refresh Token given at login by DiracX. Required. + :paramtype refresh_token: str + :keyword pilot_stamp: Pilot stamp. Required. + :paramtype pilot_stamp: str """ super().__init__(**kwargs) - self.job_ids = job_ids + self.refresh_token = refresh_token + self.pilot_stamp = pilot_stamp -class BodyJobsUnassignBulkJobsSandboxes(_serialization.Model): - """Body_jobs_unassign_bulk_jobs_sandboxes. +class BodyJobsRescheduleJobs(_serialization.Model): + """Body_jobs_reschedule_jobs. All required parameters must be populated in order to send to server. @@ -144,37 +151,32 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ super().__init__(**kwargs) self.job_ids = job_ids -class BodyAuthRefreshPilotTokens(_serialization.Model): - """Body_auth_refresh_pilot_tokens. + + +class BodyJobsUnassignBulkJobsSandboxes(_serialization.Model): + """Body_jobs_unassign_bulk_jobs_sandboxes. All required parameters must be populated in order to send to server. - :ivar refresh_token: Refresh Token given at login by DiracX. Required. - :vartype refresh_token: str - :ivar pilot_stamp: Pilot stamp. Required. - :vartype pilot_stamp: str + :ivar job_ids: Job Ids. Required. + :vartype job_ids: list[int] """ _validation = { - "refresh_token": {"required": True}, - "pilot_stamp": {"required": True}, + "job_ids": {"required": True}, } _attribute_map = { - "refresh_token": {"key": "refresh_token", "type": "str"}, - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "job_ids": {"key": "job_ids", "type": "[int]"}, } - def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: + def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ - :keyword refresh_token: Refresh Token given at login by DiracX. Required. - :paramtype refresh_token: str - :keyword pilot_stamp: Pilot stamp. Required. - :paramtype pilot_stamp: str + :keyword job_ids: Job Ids. Required. + :paramtype job_ids: list[int] """ super().__init__(**kwargs) - self.refresh_token = refresh_token - self.pilot_stamp = pilot_stamp + self.job_ids = job_ids class BodyPilotsAddPilotStamps(_serialization.Model): diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index a0e841992..b8a89752e 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -14,9 +14,9 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyAuthRefreshPilotTokens, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, - BodyAuthRefreshPilotTokens, BodyPilotsAddPilotStamps, BodyPilotsCreatePilotSecrets, BodyPilotsUpdatePilotFields, @@ -78,9 +78,9 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyAuthRefreshPilotTokens", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", - "BodyAuthRefreshPilotTokens", "BodyPilotsAddPilotStamps", "BodyPilotsCreatePilotSecrets", "BodyPilotsUpdatePilotFields", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index f07d962f1..0d752e6e7 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -94,34 +94,41 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" -class BodyJobsRescheduleJobs(_serialization.Model): - """Body_jobs_reschedule_jobs. +class BodyAuthRefreshPilotTokens(_serialization.Model): + """Body_auth_refresh_pilot_tokens. All required parameters must be populated in order to send to server. - :ivar job_ids: Job Ids. Required. - :vartype job_ids: list[int] + :ivar refresh_token: Refresh Token given at login by DiracX. Required. + :vartype refresh_token: str + :ivar pilot_stamp: Pilot stamp. Required. + :vartype pilot_stamp: str """ _validation = { - "job_ids": {"required": True}, + "refresh_token": {"required": True}, + "pilot_stamp": {"required": True}, } _attribute_map = { - "job_ids": {"key": "job_ids", "type": "[int]"}, + "refresh_token": {"key": "refresh_token", "type": "str"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, } - def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: + def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: """ - :keyword job_ids: Job Ids. Required. - :paramtype job_ids: list[int] + :keyword refresh_token: Refresh Token given at login by DiracX. Required. + :paramtype refresh_token: str + :keyword pilot_stamp: Pilot stamp. Required. + :paramtype pilot_stamp: str """ super().__init__(**kwargs) - self.job_ids = job_ids + self.refresh_token = refresh_token + self.pilot_stamp = pilot_stamp -class BodyJobsUnassignBulkJobsSandboxes(_serialization.Model): - """Body_jobs_unassign_bulk_jobs_sandboxes. +class BodyJobsRescheduleJobs(_serialization.Model): + """Body_jobs_reschedule_jobs. All required parameters must be populated in order to send to server. @@ -144,37 +151,32 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ super().__init__(**kwargs) self.job_ids = job_ids -class BodyAuthRefreshPilotTokens(_serialization.Model): - """Body_auth_refresh_pilot_tokens. + + +class BodyJobsUnassignBulkJobsSandboxes(_serialization.Model): + """Body_jobs_unassign_bulk_jobs_sandboxes. All required parameters must be populated in order to send to server. - :ivar refresh_token: Refresh Token given at login by DiracX. Required. - :vartype refresh_token: str - :ivar pilot_stamp: Pilot stamp. Required. - :vartype pilot_stamp: str + :ivar job_ids: Job Ids. Required. + :vartype job_ids: list[int] """ _validation = { - "refresh_token": {"required": True}, - "pilot_stamp": {"required": True}, + "job_ids": {"required": True}, } _attribute_map = { - "refresh_token": {"key": "refresh_token", "type": "str"}, - "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + "job_ids": {"key": "job_ids", "type": "[int]"}, } - def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> None: + def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: """ - :keyword refresh_token: Refresh Token given at login by DiracX. Required. - :paramtype refresh_token: str - :keyword pilot_stamp: Pilot stamp. Required. - :paramtype pilot_stamp: str + :keyword job_ids: Job Ids. Required. + :paramtype job_ids: list[int] """ super().__init__(**kwargs) - self.refresh_token = refresh_token - self.pilot_stamp = pilot_stamp + self.job_ids = job_ids class BodyPilotsAddPilotStamps(_serialization.Model):