Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 85 additions & 24 deletions src/firecrest/compute/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from fastapi import status, Path, HTTPException, Depends, Query
import asyncio
from time import time
from asyncssh import SSHClientProcess
from fastapi import (
WebSocket,
WebSocketDisconnect,
status,
Path,
HTTPException,
Depends,
Query,
)
from typing import Any, Annotated

# helpers
Expand Down Expand Up @@ -35,6 +46,11 @@
dependencies=[Depends(APIAuthDependency(authorize=True))],
)

router_ws = create_router(
prefix="/compute/{system_name}/jobs/{job_id}/attach",
tags=["compute"],
)


@router.post(
"",
Expand Down Expand Up @@ -79,14 +95,17 @@ async def get_jobs(
Path(alias="system_name", description="Target system"),
Depends(SchedulerClientDependency()),
],
allusers: Annotated[bool, Query(description="If set to `true` returns all jobs visible by the current user, otherwise only the current user owned jobs")] = False
allusers: Annotated[
bool,
Query(
description="If set to `true` returns all jobs visible by the current user, otherwise only the current user owned jobs"
),
] = False,
) -> Any:
username = ApiAuthHelper.get_auth().username
access_token = ApiAuthHelper.get_access_token()
jobs = await scheduler_client.get_jobs(
username=username,
jwt_token=access_token,
allusers=allusers
username=username, jwt_token=access_token, allusers=allusers
)
return {"jobs": jobs}

Expand All @@ -104,14 +123,12 @@ async def get_job(
SchedulerBaseClient,
Path(alias="system_name", description="Target system"),
Depends(SchedulerClientDependency()),
]
],
) -> Any:
username = ApiAuthHelper.get_auth().username
access_token = ApiAuthHelper.get_access_token()
jobs = await scheduler_client.get_job(
job_id=job_id,
username=username,
jwt_token=access_token
job_id=job_id, username=username, jwt_token=access_token
)
if jobs is None:
raise HTTPException(
Expand Down Expand Up @@ -147,29 +164,73 @@ async def get_job_metadata(
return {"jobs": jobs}


@router.put(
"/{job_id}/attach",
description="Attach a procces to a job by `{job_id}`",
status_code=status.HTTP_204_NO_CONTENT,
response_description="Process attached succesfully",
)
@router_ws.websocket("")
async def attach(
websocket: WebSocket,
job_id: Annotated[str, Path(description="Job id", pattern="^[a-zA-Z0-9]+$")],
job_attach: PostJobAttachRequest,
cmd: str,
token: str,
scheduler_client: Annotated[
SchedulerBaseClient,
Path(alias="system_name", description="Target system"),
Depends(SchedulerClientDependency()),
],
) -> None:
username = ApiAuthHelper.get_auth().username
access_token = ApiAuthHelper.get_access_token()
await scheduler_client.attach_command(
command=job_attach.command,
job_id=job_id,
username=username,
jwt_token=access_token,
)

# TODO: Refactor authentication dependency to support JWT as query param
username = "fireuser" # ApiAuthHelper.get_auth().username
access_token = token # ApiAuthHelper.get_access_token()

await websocket.accept()

async def read_stdout(process):
async for line in process.stdout:
await websocket.send_text(line)

async def read_stderr(process):
async for line in process.stderr:
await websocket.send_text(line)

async def write_stdin(process):
while True:
data = await websocket.receive_text()
process.stdin.write(data)
await process.stdin.drain()

async def keep_alive(process: SSHClientProcess):
while True:
process.channel.get_connection().set_extra_info(**{"last_used": time()})
await asyncio.sleep(5)

try:
async with scheduler_client.attach_command_proccess(
command=cmd,
job_id=None if job_id == "0" else job_id,
username=username,
jwt_token=access_token,
) as process:
# Run all tasks concurrently
stdout_task = asyncio.create_task(read_stdout(process))
stderr_task = asyncio.create_task(read_stderr(process))
stdin_task = asyncio.create_task(write_stdin(process))
keep_alive_task = asyncio.create_task(keep_alive(process))

# Wait until one of the tasks ends (usually stdin via disconnect)
done, pending = await asyncio.wait(
[stdout_task, stderr_task, stdin_task, keep_alive_task],
return_when=asyncio.FIRST_COMPLETED,
)

# Cancel remaining tasks
for task in pending:
task.cancel()

except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
await websocket.send_text(f"Error: {str(e)}")
await websocket.close()

return None


Expand Down
2 changes: 2 additions & 0 deletions src/firecrest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
router_liveness as status_liveness_router,
)
from firecrest.compute.router import router as compute_router
from firecrest.compute.router import router_ws as compute_router_ws
from firecrest.filesystem.router import router as filesystem_router
from lib.scheduler_clients import SlurmRestClient

Expand Down Expand Up @@ -173,6 +174,7 @@ def register_routes(app: FastAPI, settings: config.Settings):
app.include_router(status_system_router)
app.include_router(status_liveness_router)
app.include_router(compute_router)
app.include_router(compute_router_ws)
app.include_router(filesystem_router)


Expand Down
9 changes: 9 additions & 0 deletions src/lib/scheduler_clients/pbs/pbs_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ async def attach_command(
"Interactive attach is not supported in PBS CLI client"
)

async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
) -> int | None:
pass

async def get_job(
self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True
) -> List[PbsJob] | None:
Expand Down
17 changes: 10 additions & 7 deletions src/lib/scheduler_clients/scheduler_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ async def attach_command(
pass

@abstractmethod
# Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc)
async def get_job(
async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
allusers: bool = True
) -> int | None:
pass

@abstractmethod
# Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc)
async def get_job(
self, job_id: str, username: str, jwt_token: str, allusers: bool = True
) -> List[JobModel]:
pass

Expand All @@ -58,10 +64,7 @@ async def get_job_metadata(

@abstractmethod
async def get_jobs(
self,
username: str,
jwt_token: str,
allusers: bool = False
self, username: str, jwt_token: str, allusers: bool = False
) -> List[JobModel] | None:
pass

Expand Down
10 changes: 10 additions & 0 deletions src/lib/scheduler_clients/slurm/slurm_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ async def attach_command(
) -> int | None:
pass

@abstractmethod
async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
) -> None:
pass

@abstractmethod
# Note: returns multiple jobs to deal with job_id duplicates (see Slurm doc)
async def get_job(
Expand Down
21 changes: 20 additions & 1 deletion src/lib/scheduler_clients/slurm/slurm_cli_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import asyncio
from contextlib import asynccontextmanager, contextmanager
from typing import List
from packaging.version import Version

Expand Down Expand Up @@ -53,6 +54,12 @@ async def __executed_ssh_cmd(self, username, jwt_token, command, stdin=None):
async with self.ssh_client.get_client(username, jwt_token) as client:
return await client.execute(command, stdin)

@asynccontextmanager
async def __create_ssh_cmd(self, username, jwt_token, command):
async with self.ssh_client.get_client(username, jwt_token) as client:
async with client.create(command) as process:
yield process

def __init__(
self,
ssh_client: SSHClientPool,
Expand All @@ -79,9 +86,21 @@ async def attach_command(
username: str,
jwt_token: str,
) -> int | None:
srun = SrunCommand(command=command, job_id=job_id, overlap=True)
srun = SrunCommand(command=command, job_id=job_id, overlap=(job_id is not None))
return await self.__executed_ssh_cmd(username, jwt_token, srun)

@asynccontextmanager
async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
):
srun = SrunCommand(command=command, job_id=job_id, overlap=True)
async with self.__create_ssh_cmd(username, jwt_token, srun) as process:
yield process

async def get_job(
self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True
) -> List[SlurmJob] | None:
Expand Down
14 changes: 14 additions & 0 deletions src/lib/scheduler_clients/slurm/slurm_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from typing import List

from lib.scheduler_clients.slurm.models import (
Expand Down Expand Up @@ -65,6 +66,19 @@ async def attach_command(
command, job_id, username, jwt_token
)

@asynccontextmanager
async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
):
async with self.slurm_default_client.attach_command_proccess(
command, job_id, username, jwt_token
) as process:
yield process

async def get_job(
self, job_id: str | None, username: str, jwt_token: str, allusers: bool = True
) -> List[SlurmJob] | None:
Expand Down
9 changes: 9 additions & 0 deletions src/lib/scheduler_clients/slurm/slurm_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ async def attach_command(
) -> int | None:
pass

async def attach_command_proccess(
self,
command: str,
job_id: str,
username: str,
jwt_token: str,
) -> int | None:
pass

async def get_job(
self, job_id: str, username: str, jwt_token: str, allusers: bool = True
) -> List[SlurmJob] | None:
Expand Down
27 changes: 27 additions & 0 deletions src/lib/ssh_clients/ssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,33 @@ async def execute(self, command: BaseCommand, stdin: str = None):
except ChannelOpenError as e:
raise SSHConnectionError("Unable to open a new SSH channel.") from e

@asynccontextmanager
async def create(self, command: BaseCommand, stdin: str = None):
try:
# TODO: introduce stream timeout
# async with asyncio.timeout(self.execute_timeout * 100000):
command_line = command.get_command()
process = await self.conn.create_process(command_line, send_eof=False)

yield process

process.close()
await process.wait_closed()
# Log command
log_backend_command(command_line, process.exit_status)

except TimeoutError as e:
process.terminate()
process.stdin.write("\x03")
process.stdin.write_eof()
raise TimeoutLimitExceeded(
"Command execution timeout limit exceeded."
) from e
except ConnectionLost as e:
raise SSHConnectionError("Unable to establish SSH connection.") from e
except ChannelOpenError as e:
raise SSHConnectionError("Unable to open a new SSH channel.") from e

def reset_idle(
self,
) -> None:
Expand Down
Loading