async def _maybe_awaitable(self, func_result):
    """Await ``func_result`` if it is awaitable; otherwise do nothing.

    This lets hooks be either plain callables or coroutine functions.
    The awaited value is discarded.
    """
    if inspect.isawaitable(func_result):
        await func_result


async def _handle_iopub_stdin_messages(
    self,
    msg_id: str,
    output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
    stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
    timeout: t.Optional[float],
    allow_stdin: bool,
    start_time: float,
) -> None:
    """Relay IOPub (and optionally stdin) messages until the kernel goes idle.

    Parameters
    ----------
    msg_id
        Id of the request; IOPub messages whose ``parent_header`` carries a
        different ``msg_id`` are ignored.
    output_hook
        Called (and awaited, if it returns an awaitable) with every matching
        IOPub message. ``None`` disables dispatch.
    stdin_hook
        Called with messages read from the stdin channel when ``allow_stdin``
        is true. ``None`` disables stdin servicing.
    timeout
        Overall budget in seconds measured from ``start_time``; ``None``
        means no limit.
    allow_stdin
        Whether the stdin channel is polled at all.
    start_time
        ``time.monotonic()`` value the timeout is measured from.

    Raises
    ------
    TimeoutError
        When the budget is exhausted before a matching ``status``/``idle``
        message is seen.
    """
    # Upper bound on a single IOPub wait while stdin is being serviced.
    stdin_poll_interval = 0.1
    service_stdin = stdin_hook is not None and allow_stdin

    while True:
        if timeout is None:
            remaining = None
        else:
            remaining = timeout - (time.monotonic() - start_time)
            if remaining <= 0:
                raise TimeoutError("Timeout in IOPub handling")

        if service_stdin:
            await self._handle_stdin_messages(stdin_hook, allow_stdin)
            # stdin is only polled between IOPub reads, so an unbounded IOPub
            # wait would leave a pending input request unanswered forever.
            if remaining is None:
                poll_timeout = stdin_poll_interval
            else:
                poll_timeout = min(remaining, stdin_poll_interval)
        else:
            poll_timeout = remaining

        try:
            msg = await self.iopub_channel.get_msg(timeout=poll_timeout)
        except (Empty, TimeoutError, asyncio.TimeoutError):
            # Nothing arrived in this slice; re-check the deadline and stdin.
            # Any other failure propagates: there is no message to process,
            # and retrying blindly could loop forever.
            continue

        parent = msg.get("parent_header") or {}
        if parent.get("msg_id") != msg_id:
            continue

        if output_hook is not None:
            await self._maybe_awaitable(output_hook(msg))

        if (
            msg["header"]["msg_type"] == "status"
            and msg["content"].get("execution_state") == "idle"
        ):
            break


async def _handle_stdin_messages(
    self,
    stdin_hook: t.Callable[[dict[str, t.Any]], t.Any],
    allow_stdin: bool,
) -> None:
    """Poll the stdin channel once and pass any message to ``stdin_hook``.

    This is best-effort: an empty channel is silently ignored, and any other
    error is logged as a warning rather than raised. It does nothing when
    ``allow_stdin`` is false.
    """
    if not allow_stdin:
        return
    try:
        msg = await self.stdin_channel.get_msg(timeout=0.01)
        # Per-message trace; kept at debug level so normal runs stay quiet.
        self.log.debug("stdin msg: %s,%s", msg, type(msg))
        await self._maybe_awaitable(stdin_hook(msg))
    except (Empty, TimeoutError):
        pass
    except Exception:
        self.log.warning("Error handling stdin message", exc_info=True)


async def _wait_for_execution_reply(
    self, msg_id: str, timeout: t.Optional[float], start_time: float
) -> dict[str, t.Any]:
    """Wait for the reply to ``msg_id`` on the shell or control channel.

    Both channels are read concurrently; messages whose ``parent_header``
    does not carry ``msg_id`` are discarded.

    Parameters
    ----------
    msg_id
        Id of the request whose reply is awaited.
    timeout
        Overall budget in seconds measured from ``start_time``; ``None``
        means no limit.
    start_time
        ``time.monotonic()`` value the timeout is measured from.

    Returns
    -------
    dict
        The first message whose parent ``msg_id`` matches.

    Raises
    ------
    TimeoutError
        When the budget is exhausted before a matching reply arrives.
    Exception
        Any non-timeout error raised by a channel read is re-raised instead
        of being retried.
    """
    deadline = None if timeout is None else start_time + timeout
    no_message = (Empty, TimeoutError, asyncio.TimeoutError)

    while True:
        if deadline is None:
            remaining = None
        else:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("Timeout waiting for reply")

        tasks = [
            asyncio.create_task(self.shell_channel.get_msg(timeout=remaining)),
            asyncio.create_task(self.control_channel.get_msg(timeout=remaining)),
        ]
        try:
            # asyncio.wait never raises on timeout; it returns an empty
            # ``done`` set instead.
            done, _ = await asyncio.wait(
                tasks,
                timeout=remaining,
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # Also runs on cancellation, so no reader task is left behind.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

        if not done:
            raise TimeoutError("Timeout waiting for reply")

        failure: t.Optional[BaseException] = None
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None:
                if failure is None and not isinstance(exc, no_message):
                    failure = exc
                continue
            msg: dict[str, t.Any] = task.result()
            parent = msg.get("parent_header") or {}
            if parent.get("msg_id") == msg_id:
                return msg

        # A matching reply wins over a failure on the other channel; with no
        # reply, surface the failure rather than spinning on a broken channel.
        if failure is not None:
            raise failure


async def execute_interactive(
    self,
    code: str,
    silent: bool = False,
    store_history: bool = True,
    user_expressions: t.Optional[dict[str, t.Any]] = None,
    allow_stdin: t.Optional[bool] = None,
    stop_on_error: bool = True,
    timeout: t.Optional[float] = None,
    output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
    stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
) -> dict[str, t.Any]:  # type: ignore[override]  # Reason: base class sets `execute_interactive` via assignment, so mypy cannot infer override compatibility
    """Execute code in the kernel interactively via gateway.

    Sends an execute request, relays IOPub output (and stdin requests, when
    allowed) to the hooks until the kernel reports idle, then returns the
    execution reply.

    Parameters
    ----------
    code, silent, store_history, user_expressions, stop_on_error
        Forwarded unchanged to ``self.execute``.
    allow_stdin
        Whether stdin requests are serviced; defaults to ``self.allow_stdin``.
    timeout
        Overall budget in seconds for output handling plus the reply wait;
        ``None`` means no limit.
    output_hook
        Callable (sync or async) invoked with each matching IOPub message;
        defaults to ``self._output_hook_default``.
    stdin_hook
        Callable (sync or async) invoked with each stdin message; defaults
        to ``self._stdin_hook_default``.

    Returns
    -------
    dict
        The reply message for the execute request.

    Raises
    ------
    RuntimeError
        If the IOPub channel is not alive, or if an unexpected error occurs
        while handling the execution (the original error is chained).
    TimeoutError
        If ``timeout`` elapses first.
    """
    if not self.iopub_channel.is_alive():
        raise RuntimeError("IOPub channel must be running to receive output")

    if allow_stdin is None:
        allow_stdin = self.allow_stdin
    if output_hook is None:
        output_hook = self._output_hook_default
    if stdin_hook is None:
        stdin_hook = self._stdin_hook_default

    msg_id = self.execute(
        code=code,
        silent=silent,
        store_history=store_history,
        user_expressions=user_expressions,
        allow_stdin=allow_stdin,
        stop_on_error=stop_on_error,
    )

    # Both phases share one budget, measured from this instant.
    start_time = time.monotonic()

    try:
        await self._handle_iopub_stdin_messages(
            msg_id, output_hook, stdin_hook, timeout, allow_stdin, start_time
        )
        return await self._wait_for_execution_reply(msg_id, timeout, start_time)
    except asyncio.CancelledError:
        raise
    except TimeoutError:
        raise
    except Exception as e:
        self.log.error(
            f"Error during interactive execution: {e}, msg_id: {msg_id}",
            exc_info=True,
        )
        raise RuntimeError(f"Error in interactive execution: {e}") from e
async def test_gateway_kernel_client_execute_interactive(gateway_kernel_client, monkeypatch):
    """execute_interactive feeds both hooks and returns the shell reply.

    The shell, iopub and stdin channels are replaced with fakes that answer
    immediately for request id ``msg-123``. ``allow_stdin`` is passed
    explicitly so the stdin assertion does not depend on the client default.
    """
    msg_id = "msg-123"
    gateway_kernel_client.execute = MagicMock(return_value=msg_id)

    async def fake_shell_get_msg(timeout=None):
        return {"parent_header": {"msg_id": msg_id}, "msg_type": "execute_reply"}

    async def fake_iopub_get_msg(timeout=None):
        # Yield to the loop once without adding wall-clock delay.
        await asyncio.sleep(0)
        return {
            "parent_header": {"msg_id": msg_id},
            "msg_type": "status",
            "header": {"msg_type": "status"},
            "content": {"execution_state": "idle"},
        }

    async def fake_stdin_get_msg(timeout=None):
        await asyncio.sleep(0)
        return {"parent_header": {"msg_id": msg_id}, "msg_type": "input_request"}

    gateway_kernel_client.shell_channel.get_msg = fake_shell_get_msg
    gateway_kernel_client.iopub_channel.get_msg = fake_iopub_get_msg
    gateway_kernel_client.stdin_channel.get_msg = fake_stdin_get_msg

    output_msgs = []
    stdin_msgs = []

    async def output_hook(msg):
        output_msgs.append(msg)

    async def stdin_hook(msg):
        stdin_msgs.append(msg)

    reply = await gateway_kernel_client.execute_interactive(
        "print(1)",
        allow_stdin=True,
        output_hook=output_hook,
        stdin_hook=stdin_hook,
    )

    assert reply["msg_type"] == "execute_reply"
    gateway_kernel_client.execute.assert_called_once()
    # The hooks must actually have been driven, not merely accepted.
    assert [m["content"]["execution_state"] for m in output_msgs] == ["idle"]
    assert [m["msg_type"] for m in stdin_msgs] == ["input_request"]