From b6350f281e2b9993e6810053773f934879d19ba7 Mon Sep 17 00:00:00 2001 From: Elia Palme Date: Thu, 17 Jul 2025 09:04:35 +0200 Subject: [PATCH 1/2] Draft code for interactive attach end-point --- src/firecrest/compute/router.py | 109 ++++++++++++++---- src/firecrest/main.py | 2 + src/lib/scheduler_clients/pbs/pbs_client.py | 9 ++ .../scheduler_base_client.py | 17 +-- .../slurm/slurm_base_client.py | 10 ++ .../slurm/slurm_cli_client.py | 21 +++- .../scheduler_clients/slurm/slurm_client.py | 14 +++ .../slurm/slurm_rest_client.py | 9 ++ src/lib/ssh_clients/ssh_client.py | 27 +++++ 9 files changed, 186 insertions(+), 32 deletions(-) diff --git a/src/firecrest/compute/router.py b/src/firecrest/compute/router.py index c004f0c9..a3cc9e92 100644 --- a/src/firecrest/compute/router.py +++ b/src/firecrest/compute/router.py @@ -3,7 +3,18 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from fastapi import status, Path, HTTPException, Depends, Query +import asyncio +from time import time +from asyncssh import SSHClientProcess +from fastapi import ( + WebSocket, + WebSocketDisconnect, + status, + Path, + HTTPException, + Depends, + Query, +) from typing import Any, Annotated # helpers @@ -35,6 +46,11 @@ dependencies=[Depends(APIAuthDependency(authorize=True))], ) +router_ws = create_router( + prefix="/compute/{system_name}/jobs/{job_id}/attach", + tags=["compute"], +) + @router.post( "", @@ -79,14 +95,17 @@ async def get_jobs( Path(alias="system_name", description="Target system"), Depends(SchedulerClientDependency()), ], - allusers: Annotated[bool, Query(description="If set to `true` returns all jobs visible by the current user, otherwise only the current user owned jobs")] = False + allusers: Annotated[ + bool, + Query( + description="If set to `true` returns all jobs visible by the current user, otherwise only the current user owned jobs" + ), + ] = False, ) -> Any: username = ApiAuthHelper.get_auth().username access_token = ApiAuthHelper.get_access_token() jobs = await scheduler_client.get_jobs( - username=username, - jwt_token=access_token, - allusers=allusers + username=username, jwt_token=access_token, allusers=allusers ) return {"jobs": jobs} @@ -104,14 +123,12 @@ async def get_job( SchedulerBaseClient, Path(alias="system_name", description="Target system"), Depends(SchedulerClientDependency()), - ] + ], ) -> Any: username = ApiAuthHelper.get_auth().username access_token = ApiAuthHelper.get_access_token() jobs = await scheduler_client.get_job( - job_id=job_id, - username=username, - jwt_token=access_token + job_id=job_id, username=username, jwt_token=access_token ) if jobs is None: raise HTTPException( @@ -147,29 +164,73 @@ async def get_job_metadata( return {"jobs": jobs} -@router.put( - "/{job_id}/attach", - description="Attach a procces to a job by `{job_id}`", - status_code=status.HTTP_204_NO_CONTENT, - response_description="Process attached succesfully", -) +@router_ws.websocket("") async def attach( + websocket: WebSocket, job_id: Annotated[str, Path(description="Job id", pattern="^[a-zA-Z0-9]+$")], - job_attach: PostJobAttachRequest, + entrypoint: str, + token: str, scheduler_client: Annotated[ SchedulerBaseClient, Path(alias="system_name", description="Target system"), Depends(SchedulerClientDependency()), ], ) -> None: - username = ApiAuthHelper.get_auth().username - access_token = ApiAuthHelper.get_access_token() - await scheduler_client.attach_command( - command=job_attach.command, - job_id=job_id, - username=username, - jwt_token=access_token, - ) + + # TODO: Refactor authentication dependency to support JWT as query param + username = "fireuser" # ApiAuthHelper.get_auth().username + access_token = token # ApiAuthHelper.get_access_token() + + await websocket.accept() + + async def read_stdout(process): + async for line in process.stdout: + await websocket.send_text(line) + + async def read_stderr(process): + async for line in process.stderr: + await websocket.send_text(line) + + async def write_stdin(process): + while True: + data = await websocket.receive_text() + process.stdin.write(data) + await process.stdin.drain() + + async def keep_alive(process: SSHClientProcess): + while True: + process.channel.get_connection().set_extra_info(**{"last_used": time()}) + await asyncio.sleep(5) + + try: + async with scheduler_client.attach_command_proccess( + command=entrypoint, + job_id=None if job_id == "0" else job_id, + username=username, + jwt_token=access_token, + ) as process: + # Run all tasks concurrently + stdout_task = asyncio.create_task(read_stdout(process)) + stderr_task = asyncio.create_task(read_stderr(process)) + stdin_task = asyncio.create_task(write_stdin(process)) + keep_alive_task = asyncio.create_task(keep_alive(process)) + + # Wait until one of the tasks ends (usually stdin via disconnect) + done, pending = await asyncio.wait( + [stdout_task, stderr_task, stdin_task, keep_alive_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + + except WebSocketDisconnect: + print("WebSocket disconnected") + except Exception as e: + await websocket.send_text(f"Error: {str(e)}") + await websocket.close() + return None diff --git a/src/firecrest/main.py b/src/firecrest/main.py index c08185d2..21a25463 100644 --- a/src/firecrest/main.py +++ b/src/firecrest/main.py @@ -44,6 +44,7 @@ router_liveness as status_liveness_router, ) from firecrest.compute.router import router as compute_router +from firecrest.compute.router import router_ws as compute_router_ws from firecrest.filesystem.router import router as filesystem_router from lib.scheduler_clients import SlurmRestClient @@ -173,6 +174,7 @@ def register_routes(app: FastAPI, settings: config.Settings): app.include_router(status_system_router) app.include_router(status_liveness_router) app.include_router(compute_router) + app.include_router(compute_router_ws) app.include_router(filesystem_router) diff --git a/src/lib/scheduler_clients/pbs/pbs_client.py b/src/lib/scheduler_clients/pbs/pbs_client.py index 3d0f09d5..49179393 100644 --- a/src/lib/scheduler_clients/pbs/pbs_client.py +++ b/src/lib/scheduler_clients/pbs/pbs_client.py @@ -83,6 +83,15 @@ async def attach_command( "Interactive attach is not supported in PBS CLI client" ) + async def attach_command_proccess( + self, + command: str, + job_id: str, + username: str, + jwt_token: str, + ) -> int | None: + pass + async def get_job( self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True ) -> List[PbsJob] | None: diff --git a/src/lib/scheduler_clients/scheduler_base_client.py b/src/lib/scheduler_clients/scheduler_base_client.py index 40535b3d..c3099bb6 100644 --- a/src/lib/scheduler_clients/scheduler_base_client.py +++ b/src/lib/scheduler_clients/scheduler_base_client.py @@ -39,13 +39,19 @@ async def attach_command( pass @abstractmethod - # Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc) - async def get_job( + async def attach_command_proccess( self, + command: str, job_id: str, username: str, jwt_token: str, - allusers: bool = True + ) -> int | None: + pass + + @abstractmethod + # Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc) + async def get_job( + self, job_id: str, username: str, jwt_token: str, allusers: bool = True ) -> List[JobModel]: pass @@ -58,10 +64,7 @@ async def get_job_metadata( @abstractmethod async def get_jobs( - self, - username: str, - jwt_token: str, - allusers: bool = False + self, username: str, jwt_token: str, allusers: bool = False ) -> List[JobModel] | None: pass diff --git a/src/lib/scheduler_clients/slurm/slurm_base_client.py b/src/lib/scheduler_clients/slurm/slurm_base_client.py index 47244b34..1e5dbe58 100644 --- a/src/lib/scheduler_clients/slurm/slurm_base_client.py +++ b/src/lib/scheduler_clients/slurm/slurm_base_client.py @@ -38,6 +38,16 @@ async def attach_command( ) -> int | None: pass + @abstractmethod + async def attach_command_proccess( + self, + command: str, + job_id: str, + username: str, + jwt_token: str, + ) -> None: + pass + @abstractmethod # Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc) async def get_job( diff --git a/src/lib/scheduler_clients/slurm/slurm_cli_client.py b/src/lib/scheduler_clients/slurm/slurm_cli_client.py index b651c01a..d50dd24d 100644 --- a/src/lib/scheduler_clients/slurm/slurm_cli_client.py +++ b/src/lib/scheduler_clients/slurm/slurm_cli_client.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause import asyncio +from contextlib import asynccontextmanager, contextmanager from typing import List from packaging.version import Version @@ -53,6 +54,12 @@ async def __executed_ssh_cmd(self, username, jwt_token, command, stdin=None): async with self.ssh_client.get_client(username, jwt_token) as client: return await client.execute(command, stdin) + @asynccontextmanager + async def __create_ssh_cmd(self, username, jwt_token, command): + async with self.ssh_client.get_client(username, jwt_token) as client: + async with client.create(command) as process: + yield process + def __init__( self, ssh_client: SSHClientPool, @@ -79,9 +86,21 @@ async def attach_command( username: str, jwt_token: str, ) -> int | None: - srun = SrunCommand(command=command, job_id=job_id, overlap=True) + srun = SrunCommand(command=command, job_id=job_id, overlap=(job_id is not None)) return await self.__executed_ssh_cmd(username, jwt_token, srun) + @asynccontextmanager + async def attach_command_proccess( + self, + command: str, + job_id: str, + username: str, + jwt_token: str, + ): + srun = SrunCommand(command=command, job_id=job_id, overlap=True) + async with self.__create_ssh_cmd(username, jwt_token, srun) as process: + yield process + async def get_job( self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True ) -> List[SlurmJob] | None: diff --git a/src/lib/scheduler_clients/slurm/slurm_client.py b/src/lib/scheduler_clients/slurm/slurm_client.py index a3a3ff19..e730a302 100644 --- a/src/lib/scheduler_clients/slurm/slurm_client.py +++ b/src/lib/scheduler_clients/slurm/slurm_client.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from typing import List from lib.scheduler_clients.slurm.models import ( @@ -65,6 +66,19 @@ async def attach_command( command, job_id, username, jwt_token ) + @asynccontextmanager + async def attach_command_proccess( + self, + command: str, + job_id: str, + username: str, + jwt_token: str, + ): + async with self.slurm_default_client.attach_command_proccess( + command, job_id, username, jwt_token + ) as process: + yield process + async def get_job( self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True ) -> List[SlurmJob] | None: diff --git a/src/lib/scheduler_clients/slurm/slurm_rest_client.py b/src/lib/scheduler_clients/slurm/slurm_rest_client.py index 2a877593..0e26bb1d 100644 --- a/src/lib/scheduler_clients/slurm/slurm_rest_client.py +++ b/src/lib/scheduler_clients/slurm/slurm_rest_client.py @@ -139,6 +139,15 @@ async def attach_command( ) -> int | None: pass + async def attach_command_proccess( + self, + command: str, + job_id: str, + username: str, + jwt_token: str, + ) -> int | None: + pass + async def get_job( self, job_id: str, username: str, jwt_token: str, allusers: bool = True ) -> List[SlurmJob] | None: diff --git a/src/lib/ssh_clients/ssh_client.py b/src/lib/ssh_clients/ssh_client.py index a124c3fb..bb4f1c36 100644 --- a/src/lib/ssh_clients/ssh_client.py +++ b/src/lib/ssh_clients/ssh_client.py @@ -109,6 +109,33 @@ async def execute(self, command: BaseCommand, stdin: str = None): except ChannelOpenError as e: raise SSHConnectionError("Unable to open a new SSH channel.") from e + @asynccontextmanager + async def create(self, command: BaseCommand, stdin: str = None): + try: + # TODO: introduce stream timeout + # async with asyncio.timeout(self.execute_timeout * 100000): + command_line = command.get_command() + process = await self.conn.create_process(command_line, send_eof=False) + + yield process + + process.close() + await process.wait_closed() + # Log command + log_backend_command(command_line, process.exit_status) + + except TimeoutError as e: + process.terminate() + process.stdin.write("\x03") + process.stdin.write_eof() + raise TimeoutLimitExceeded( + "Command execution timeout limit exceeded." + ) from e + except ConnectionLost as e: + raise SSHConnectionError("Unable to establish SSH connection.") from e + except ChannelOpenError as e: + raise SSHConnectionError("Unable to open a new SSH channel.") from e + def reset_idle( self, ) -> None: From 34260fc42563771d8d38da7daeb74378b41f0569 Mon Sep 17 00:00:00 2001 From: Elia Palme Date: Thu, 17 Jul 2025 09:27:12 +0200 Subject: [PATCH 2/2] rename the command param --- src/firecrest/compute/router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/firecrest/compute/router.py b/src/firecrest/compute/router.py index a3cc9e92..36ec832c 100644 --- a/src/firecrest/compute/router.py +++ b/src/firecrest/compute/router.py @@ -168,7 +168,7 @@ async def get_job_metadata( async def attach( websocket: WebSocket, job_id: Annotated[str, Path(description="Job id", pattern="^[a-zA-Z0-9]+$")], - entrypoint: str, + cmd: str, token: str, scheduler_client: Annotated[ SchedulerBaseClient, @@ -204,7 +204,7 @@ async def keep_alive(process: SSHClientProcess): try: async with scheduler_client.attach_command_proccess( - command=entrypoint, + command=cmd, job_id=None if job_id == "0" else job_id, username=username, jwt_token=access_token,