diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 056381cc9..7335cc831 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,6 +31,8 @@ jobs: pip install --upgrade setuptools pip pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5' pip freeze + - name: Check types + run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py jupyter_client/launcher.py - name: Run the tests run: py.test --cov jupyter_client -v jupyter_client - name: Code coverage diff --git a/jupyter_client/adapter.py b/jupyter_client/adapter.py index 94109dceb..e4e09a54c 100644 --- a/jupyter_client/adapter.py +++ b/jupyter_client/adapter.py @@ -5,10 +5,14 @@ import re import json +from typing import List, Tuple, Dict, Any from jupyter_client import protocol_version_info -def code_to_line(code, cursor_pos): +def code_to_line( + code: str, + cursor_pos: int +) -> Tuple[str, int]: """Turn a multiline code block and cursor position into a single line and new cursor position. @@ -29,14 +33,17 @@ def code_to_line(code, cursor_pos): _end_bracket = re.compile(r'\([^\(]*$', re.UNICODE) _identifier = re.compile(r'[a-z_][0-9a-z._]*', re.I|re.UNICODE) -def extract_oname_v4(code, cursor_pos): +def extract_oname_v4( + code: str, + cursor_pos: int +) -> str: """Reimplement token-finding logic from IPython 2.x javascript - + for adapting object_info_request from v5 to v4 """ - + line, _ = code_to_line(code, cursor_pos) - + oldline = line line = _match_bracket.sub('', line) while oldline != line: @@ -58,29 +65,44 @@ class Adapter(object): Override message_type(msg) methods to create adapters. 
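The new CI step runs mypy over the listed modules, and `adapter.py` picks up the annotations it needs to pass. The hunk above only shows the new signature of `code_to_line`; its body is outside the diff. For orientation, a minimal sketch of the documented contract (the single line under the cursor, plus the cursor's offset within that line) could look like this. It is illustrative only, not the shipped implementation:

```python
from typing import Tuple

def code_to_line_sketch(code: str, cursor_pos: int) -> Tuple[str, int]:
    # Hypothetical reimplementation, for illustration only.
    if not code:
        return "", 0
    line = ""
    for line in code.splitlines():
        if cursor_pos <= len(line):
            break
        cursor_pos -= len(line) + 1  # skip the line plus its newline
    return line, cursor_pos

print(code_to_line_sketch("import os\nos.pa", 15))  # -> ('os.pa', 5)
```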
""" - msg_type_map = {} + msg_type_map: Dict[str, str] = {} - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: return msg - def update_metadata(self, msg): + def update_metadata( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: return msg - def update_msg_type(self, msg): + def update_msg_type( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: header = msg['header'] msg_type = header['msg_type'] if msg_type in self.msg_type_map: msg['msg_type'] = header['msg_type'] = self.msg_type_map[msg_type] return msg - def handle_reply_status_error(self, msg): + def handle_reply_status_error( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """This will be called *instead of* the regular handler on any reply with status != ok """ return msg - def __call__(self, msg): + def __call__( + self, + msg: Dict[str, Any] + ): msg = self.update_header(msg) msg = self.update_metadata(msg) msg = self.update_msg_type(msg) @@ -95,7 +117,9 @@ def __call__(self, msg): return self.handle_reply_status_error(msg) return handler(msg) -def _version_str_to_list(version): +def _version_str_to_list( + version: str +) -> List[int]: """convert a version string to a list of ints non-int segments are excluded @@ -121,14 +145,20 @@ class V5toV4(Adapter): 'inspect_reply' : 'object_info_reply', } - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['header'].pop('version', None) msg['parent_header'].pop('version', None) return msg # shell channel - def kernel_info_reply(self, msg): + def kernel_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: v4c = {} content = msg['content'] for key in ('language_version', 'protocol_version'): @@ -145,18 +175,27 @@ def kernel_info_reply(self, msg): msg['content'] = v4c return msg - def execute_request(self, msg): + def execute_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault('user_variables', []) return msg - def execute_reply(self, msg): + def execute_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault('user_variables', {}) # TODO: handle payloads return msg - def complete_request(self, msg): + def complete_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] code = content['code'] cursor_pos = content['cursor_pos'] @@ -169,7 +208,10 @@ def complete_request(self, msg): new_content['cursor_pos'] = cursor_pos return msg - def complete_reply(self, msg): + def complete_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] cursor_start = content.pop('cursor_start') cursor_end = content.pop('cursor_end') @@ -178,7 +220,10 @@ def complete_reply(self, msg): content.pop('metadata', None) return msg - def object_info_request(self, msg): + def object_info_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] code = content['code'] cursor_pos = content['cursor_pos'] @@ -189,19 +234,28 @@ def object_info_request(self, msg): new_content['detail_level'] = content['detail_level'] return msg - def object_info_reply(self, msg): + def object_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """inspect_reply can't be easily backward compatible""" msg['content'] = {'found' : False, 'oname' : 'unknown'} return msg # iopub channel - def stream(self, msg): + def stream( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content['data'] = 
content.pop('text') return msg - def display_data(self, msg): + def display_data( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault("source", "display") data = content['data'] @@ -215,7 +269,10 @@ def display_data(self, msg): # stdin channel - def input_request(self, msg): + def input_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['content'].pop('password', None) return msg @@ -227,7 +284,10 @@ class V4toV5(Adapter): # invert message renames above msg_type_map = {v:k for k,v in V5toV4.msg_type_map.items()} - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['header']['version'] = self.version if msg['parent_header']: msg['parent_header']['version'] = self.version @@ -235,7 +295,10 @@ def update_header(self, msg): # shell channel - def kernel_info_reply(self, msg): + def kernel_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] for key in ('protocol_version', 'ipython_version'): if key in content: @@ -257,7 +320,10 @@ def kernel_info_reply(self, msg): content['banner'] = '' return msg - def execute_request(self, msg): + def execute_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] user_variables = content.pop('user_variables', []) user_expressions = content.setdefault('user_expressions', {}) @@ -265,7 +331,10 @@ def execute_request(self, msg): user_expressions[v] = v return msg - def execute_reply(self, msg): + def execute_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] user_expressions = content.setdefault('user_expressions', {}) user_variables = content.pop('user_variables', {}) @@ -281,7 +350,10 @@ def execute_reply(self, msg): return msg - def complete_request(self, msg): + def complete_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: old_content = msg['content'] new_content = msg['content'] = {} @@ -289,7 +361,10 @@ def complete_request(self, msg): new_content['cursor_pos'] = old_content['cursor_pos'] return msg - def complete_reply(self, msg): + def complete_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: # complete_reply needs more context than we have to get cursor_start and end. # use special end=null to indicate current cursor position and negative offset # for start relative to the cursor. 
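The comment above `complete_reply` explains the v4 to v5 cursor convention: `cursor_end=None` means "the current cursor position" and a negative `cursor_start` is an offset back from the cursor. A small helper shows how a consumer could resolve that convention into absolute offsets (illustrative; not part of this change):

```python
from typing import Optional, Tuple

def resolve_cursor_span(cursor_pos: int, cursor_start: int,
                        cursor_end: Optional[int]) -> Tuple[int, int]:
    # cursor_end=None encodes "the current cursor position"
    end = cursor_pos if cursor_end is None else cursor_end
    # a negative cursor_start is relative to the cursor
    start = cursor_pos + cursor_start if cursor_start < 0 else cursor_start
    return start, end

# a v4 reply matched 'os.pa' (5 chars) while the cursor sat at offset 15:
print(resolve_cursor_span(15, -5, None))  # -> (10, 15)
```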
@@ -306,7 +381,10 @@ def complete_reply(self, msg): new_content['metadata'] = {} return msg - def inspect_request(self, msg): + def inspect_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] name = content['oname'] @@ -316,7 +394,10 @@ def inspect_request(self, msg): new_content['detail_level'] = content['detail_level'] return msg - def inspect_reply(self, msg): + def inspect_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """inspect_reply can't be easily backward compatible""" content = msg['content'] new_content = msg['content'] = {'status' : 'ok'} @@ -340,12 +421,18 @@ def inspect_reply(self, msg): # iopub channel - def stream(self, msg): + def stream( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content['text'] = content.pop('data') return msg - def display_data(self, msg): + def display_data( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.pop("source", None) data = content['data'] @@ -359,13 +446,19 @@ def display_data(self, msg): # stdin channel - def input_request(self, msg): + def input_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['content'].setdefault('password', False) return msg -def adapt(msg, to_version=protocol_version_info[0]): +def adapt( + msg: Dict[str, Any], + to_version: int =protocol_version_info[0] + ) -> Dict[str, Any]: """Adapt a single message to a target version Parameters diff --git a/jupyter_client/asynchronous/channels.py b/jupyter_client/asynchronous/channels.py deleted file mode 100644 index b6f49bd36..000000000 --- a/jupyter_client/asynchronous/channels.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Async channels""" - -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. - -from queue import Queue, Empty - - -class ZMQSocketChannel(object): - """A ZMQ socket in an async API""" - session = None - socket = None - stream = None - _exiting = False - proxy_methods = [] - - def __init__(self, socket, session, loop=None): - """Create a channel. - - Parameters - ---------- - socket : :class:`zmq.asyncio.Socket` - The ZMQ socket to use. - session : :class:`session.Session` - The session to use. - loop - Unused here, for other implementations - """ - super().__init__() - - self.socket = socket - self.session = session - - async def _recv(self, **kwargs): - msg = await self.socket.recv_multipart(**kwargs) - ident,smsg = self.session.feed_identities(msg) - return self.session.deserialize(smsg) - - async def get_msg(self, timeout=None): - """ Gets a message if there is one that is ready. """ - if timeout is not None: - timeout *= 1000 # seconds to ms - ready = await self.socket.poll(timeout) - - if ready: - return await self._recv() - else: - raise Empty - - async def get_msgs(self): - """ Get all messages that are currently ready. """ - msgs = [] - while True: - try: - msgs.append(await self.get_msg()) - except Empty: - break - return msgs - - async def msg_ready(self): - """ Is there a message that has been received? 
""" - return bool(await self.socket.poll(timeout=0)) - - def close(self): - if self.socket is not None: - try: - self.socket.close(linger=0) - except Exception: - pass - self.socket = None - stop = close - - def is_alive(self): - return (self.socket is not None) - - def send(self, msg): - """Pass a message to the ZMQ socket to send - """ - self.session.send(self.socket, msg) - - def start(self): - pass diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 1a21e3ac8..86fb8737e 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -2,59 +2,21 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from functools import partial -from getpass import getpass -from queue import Empty -import sys -import time +from traitlets import (Type, Instance) # type: ignore +from jupyter_client.channels import HBChannel, ZMQSocketChannel +from jupyter_client.client import KernelClient, reqrep -import zmq -import zmq.asyncio -import asyncio -from traitlets import (Type, Instance) -from jupyter_client.channels import HBChannel -from jupyter_client.client import KernelClient -from .channels import ZMQSocketChannel - - -def reqrep(meth, channel='shell'): - def wrapped(self, *args, **kwargs): +def wrapped(meth, channel): + def _(self, *args, **kwargs): reply = kwargs.pop('reply', False) timeout = kwargs.pop('timeout', None) msg_id = meth(self, *args, **kwargs) if not reply: return msg_id + return self._async_recv_reply(msg_id, timeout=timeout, channel=channel) + return _ - return self._recv_reply(msg_id, timeout=timeout, channel=channel) - - if not meth.__doc__: - # python -OO removes docstrings, - # so don't bother building the wrapped docstring - return wrapped - - basedoc, _ = meth.__doc__.split('Returns\n', 1) - parts = [basedoc.strip()] - if 'Parameters' not in basedoc: - parts.append(""" - Parameters - ---------- - """) - parts.append(""" - reply: bool (default: False) - Whether to wait for and return reply - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - Returns - ------- - msg_id: str - The msg_id of the request sent, if reply=False (default) - reply: dict - The reply message for this request, if reply=True - """) - wrapped.__doc__ = '\n'.join(parts) - return wrapped class AsyncKernelClient(KernelClient): """A KernelClient with async APIs @@ -63,98 +25,16 @@ class AsyncKernelClient(KernelClient): raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds. """ - # The PyZMQ Context to use for communication with the kernel. 
- context = Instance(zmq.asyncio.Context) - def _context_default(self): - return zmq.asyncio.Context() - #-------------------------------------------------------------------------- # Channel proxy methods #-------------------------------------------------------------------------- - async def get_shell_msg(self, *args, **kwargs): - """Get a message from the shell channel""" - return await self.shell_channel.get_msg(*args, **kwargs) - - async def get_iopub_msg(self, *args, **kwargs): - """Get a message from the iopub channel""" - return await self.iopub_channel.get_msg(*args, **kwargs) - - async def get_stdin_msg(self, *args, **kwargs): - """Get a message from the stdin channel""" - return await self.stdin_channel.get_msg(*args, **kwargs) - - async def get_control_msg(self, *args, **kwargs): - """Get a message from the control channel""" - return await self.control_channel.get_msg(*args, **kwargs) - - @property - def hb_channel(self): - """Get the hb channel object for this kernel.""" - if self._hb_channel is None: - url = self._make_url('hb') - self.log.debug("connecting heartbeat channel to %s", url) - loop = asyncio.new_event_loop() - self._hb_channel = self.hb_channel_class( - self.context, self.session, url, loop - ) - return self._hb_channel - - async def wait_for_ready(self, timeout=None): - """Waits for a response when a client is blocked - - - Sets future time for timeout - - Blocks on shell channel until a message is received - - Exit if the kernel has died - - If client times out before receiving a message from the kernel, send RuntimeError - - Flush the IOPub channel - """ - if timeout is None: - abs_timeout = float('inf') - else: - abs_timeout = time.time() + timeout + get_shell_msg = KernelClient._async_get_shell_msg + get_iopub_msg = KernelClient._async_get_iopub_msg + get_stdin_msg = KernelClient._async_get_stdin_msg + get_control_msg = KernelClient._async_get_control_msg - from ..manager import KernelManager - if not isinstance(self.parent, KernelManager): - # This Client was not created by a KernelManager, - # so wait for kernel to become responsive to heartbeats - # before checking for kernel_info reply - while not self.is_alive(): - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) - await asyncio.sleep(0.2) - - # Wait for kernel info reply on shell channel - while True: - self.kernel_info() - try: - msg = await self.shell_channel.get_msg(timeout=1) - except Empty: - pass - else: - if msg['msg_type'] == 'kernel_info_reply': - # Checking that IOPub is connected. If it is not connected, start over. 
- try: - await self.iopub_channel.get_msg(timeout=0.2) - except Empty: - pass - else: - self._handle_kernel_info_reply(msg) - break - - if not await self.is_alive(): - raise RuntimeError('Kernel died before replying to kernel_info') - - # Check if current time is ready check time plus timeout - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) - - # Flush IOPub channel - while True: - try: - msg = await self.iopub_channel.get_msg(timeout=0.2) - except Empty: - break + wait_for_ready = KernelClient._async_wait_for_ready # The classes to use for the various channels shell_channel_class = Type(ZMQSocketChannel) @@ -164,232 +44,19 @@ async def wait_for_ready(self, timeout=None): control_channel_class = Type(ZMQSocketChannel) - async def _recv_reply(self, msg_id, timeout=None, channel='shell'): - """Receive and return the reply for a given request""" - if timeout is not None: - deadline = time.monotonic() + timeout - while True: - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - try: - if channel == 'control': - reply = await self.get_control_msg(timeout=timeout) - else: - reply = await self.get_shell_msg(timeout=timeout) - except Empty as e: - raise TimeoutError("Timeout waiting for reply") from e - if reply['parent_header'].get('msg_id') != msg_id: - # not my reply, someone may have forgotten to retrieve theirs - continue - return reply + _recv_reply = KernelClient._async_recv_reply # replies come on the shell channel - execute = reqrep(KernelClient.execute) - history = reqrep(KernelClient.history) - complete = reqrep(KernelClient.complete) - inspect = reqrep(KernelClient.inspect) - kernel_info = reqrep(KernelClient.kernel_info) - comm_info = reqrep(KernelClient.comm_info) - - # replies come on the control channel - shutdown = reqrep(KernelClient.shutdown, channel='control') - - - def _stdin_hook_default(self, msg): - """Handle an input request""" - content = msg['content'] - if content.get('password', False): - prompt = getpass - else: - prompt = input - - try: - raw_data = prompt(content["prompt"]) - except EOFError: - # turn EOFError into EOF character - raw_data = '\x04' - except KeyboardInterrupt: - sys.stdout.write('\n') - return - - # only send stdin reply if there *was not* another request - # or execution finished while we were reading. - if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): - self.input(raw_data) - - def _output_hook_default(self, msg): - """Default hook for redisplaying plain-text output""" - msg_type = msg['header']['msg_type'] - content = msg['content'] - if msg_type == 'stream': - stream = getattr(sys, content['name']) - stream.write(content['text']) - elif msg_type in ('display_data', 'execute_result'): - sys.stdout.write(content['data'].get('text/plain', '')) - elif msg_type == 'error': - print('\n'.join(content['traceback']), file=sys.stderr) + execute = reqrep(wrapped, KernelClient._execute) + history = reqrep(wrapped, KernelClient._history) + complete = reqrep(wrapped, KernelClient._complete) + inspect = reqrep(wrapped, KernelClient._inspect) + kernel_info = reqrep(wrapped, KernelClient._kernel_info) + comm_info = reqrep(wrapped, KernelClient._comm_info) - def _output_hook_kernel(self, session, socket, parent_header, msg): - """Output hook when running inside an IPython kernel + is_alive = KernelClient._async_is_alive + execute_interactive = KernelClient._async_execute_interactive - adds rich output support. 
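`_output_hook_default` and `_output_hook_kernel` are not gone; they move into the shared `KernelClient` base (see `jupyter_client/client.py` below), and callers can still supply their own hook. A minimal custom hook that just captures stream text might look like this; the `client` in the commented usage is assumed to be a started, ready `AsyncKernelClient`:

```python
collected = []

def capture_hook(msg):
    # keep stream output, ignore display_data/execute_result/error
    if msg['header']['msg_type'] == 'stream':
        collected.append(msg['content']['text'])

# usage sketch:
#   reply = await client.execute_interactive("print('hi')",
#                                            output_hook=capture_hook)
#   print(''.join(collected))
```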
- """ - msg_type = msg['header']['msg_type'] - if msg_type in ('display_data', 'execute_result', 'error'): - session.send(socket, msg_type, msg['content'], parent=parent_header) - else: - self._output_hook_default(msg) - - async def is_alive(self): - """Is the kernel process still running?""" - from ..manager import KernelManager, AsyncKernelManager - if isinstance(self.parent, KernelManager): - # This KernelClient was created by a KernelManager, - # we can ask the parent KernelManager: - if isinstance(self.parent, AsyncKernelManager): - return await self.parent.is_alive() - return self.parent.is_alive() - if self._hb_channel is not None: - # We don't have access to the KernelManager, - # so we use the heartbeat. - return self._hb_channel.is_beating() - else: - # no heartbeat and not local, we can't tell if it's running, - # so naively return True - return True - - async def execute_interactive(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True, - timeout=None, output_hook=None, stdin_hook=None, - ): - """Execute code in the kernel interactively - - Output will be redisplayed, and stdin prompts will be relayed as well. - If an IPython kernel is detected, rich output will be displayed. - - You can pass a custom output_hook callable that will be called - with every IOPub message that is produced instead of the default redisplay. - - Parameters - ---------- - code : str - A string of code in the kernel's language. - - silent : bool, optional (default False) - If set, the kernel will execute the code as quietly possible, and - will force store_history to be False. - - store_history : bool, optional (default True) - If set, the kernel will store command history. This is forced - to be False if silent is True. - - user_expressions : dict, optional - A dict mapping names to expressions to be evaluated in the user's - dict. The expression values are returned as strings formatted using - :func:`repr`. - - allow_stdin : bool, optional (default self.allow_stdin) - Flag for whether the kernel can send stdin requests to frontends. - - Some frontends (e.g. the Notebook) do not support stdin requests. - If raw_input is called from code executed from such a frontend, a - StdinNotImplementedError will be raised. - - stop_on_error: bool, optional (default True) - Flag whether to abort the execution queue, if an exception is encountered. - - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - output_hook: callable(msg) - Function to be called with output messages. - If not specified, output will be redisplayed. - - stdin_hook: callable(msg) - Function to be called with stdin_request messages. - If not specified, input/getpass will be called. 
- - Returns - ------- - reply: dict - The reply message for this request - """ - if not self.iopub_channel.is_alive(): - raise RuntimeError("IOPub channel must be running to receive output") - if allow_stdin is None: - allow_stdin = self.allow_stdin - if allow_stdin and not self.stdin_channel.is_alive(): - raise RuntimeError("stdin channel must be running to allow input") - msg_id = await self.execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, - ) - if stdin_hook is None: - stdin_hook = self._stdin_hook_default - if output_hook is None: - # detect IPython kernel - if 'IPython' in sys.modules: - from IPython import get_ipython - ip = get_ipython() - in_kernel = getattr(ip, 'kernel', False) - if in_kernel: - output_hook = partial( - self._output_hook_kernel, - ip.display_pub.session, - ip.display_pub.pub_socket, - ip.display_pub.parent_header, - ) - if output_hook is None: - # default: redisplay plain-text outputs - output_hook = self._output_hook_default - - # set deadline based on timeout - if timeout is not None: - deadline = time.monotonic() + timeout - else: - timeout_ms = None - - poller = zmq.Poller() - iopub_socket = self.iopub_channel.socket - poller.register(iopub_socket, zmq.POLLIN) - if allow_stdin: - stdin_socket = self.stdin_channel.socket - poller.register(stdin_socket, zmq.POLLIN) - else: - stdin_socket = None - - # wait for output and redisplay it - while True: - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - timeout_ms = 1e3 * timeout - events = dict(poller.poll(timeout_ms)) - if not events: - raise TimeoutError("Timeout waiting for output") - if stdin_socket in events: - req = await self.stdin_channel.get_msg(timeout=0) - stdin_hook(req) - continue - if iopub_socket not in events: - continue - - msg = await self.iopub_channel.get_msg(timeout=0) - - if msg['parent_header'].get('msg_id') != msg_id: - # not from my request - continue - output_hook(msg) - - # stop on idle - if msg['header']['msg_type'] == 'status' and \ - msg['content']['execution_state'] == 'idle': - break - - # output is done, get the reply - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - return await self._recv_reply(msg_id, timeout=timeout) + # replies come on the control channel + shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control') diff --git a/jupyter_client/blocking/channels.py b/jupyter_client/blocking/channels.py deleted file mode 100644 index ab24b692d..000000000 --- a/jupyter_client/blocking/channels.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Blocking channels - -Useful for test suites and blocking terminal interfaces. -""" - -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. - -from queue import Queue, Empty - - -class ZMQSocketChannel(object): - """A ZMQ socket in a simple blocking API""" - session = None - socket = None - stream = None - _exiting = False - proxy_methods = [] - - def __init__(self, socket, session, loop=None): - """Create a channel. - - Parameters - ---------- - socket : :class:`zmq.Socket` - The ZMQ socket to use. - session : :class:`session.Session` - The session to use. 
- loop - Unused here, for other implementations - """ - super().__init__() - - self.socket = socket - self.session = session - - def _recv(self, **kwargs): - msg = self.socket.recv_multipart(**kwargs) - ident,smsg = self.session.feed_identities(msg) - return self.session.deserialize(smsg) - - def get_msg(self, block=True, timeout=None): - """ Gets a message if there is one that is ready. """ - if block: - if timeout is not None: - timeout *= 1000 # seconds to ms - ready = self.socket.poll(timeout) - else: - ready = self.socket.poll(timeout=0) - - if ready: - return self._recv() - else: - raise Empty - - def get_msgs(self): - """ Get all messages that are currently ready. """ - msgs = [] - while True: - try: - msgs.append(self.get_msg(block=False)) - except Empty: - break - return msgs - - def msg_ready(self): - """ Is there a message that has been received? """ - return bool(self.socket.poll(timeout=0)) - - def close(self): - if self.socket is not None: - try: - self.socket.close(linger=0) - except Exception: - pass - self.socket = None - stop = close - - def is_alive(self): - return (self.socket is not None) - - def send(self, msg): - """Pass a message to the ZMQ socket to send - """ - self.session.send(self.socket, msg) - - def start(self): - pass diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 5f11b798a..34dafdf43 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -5,58 +5,22 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from functools import partial -from getpass import getpass -from queue import Empty -import sys -import time +from traitlets import Type # type: ignore +from jupyter_client.channels import HBChannel, ZMQSocketChannel +from jupyter_client.client import KernelClient, reqrep +from ..utils import run_sync -import zmq -from time import monotonic -from traitlets import Type -from jupyter_client.channels import HBChannel -from jupyter_client.client import KernelClient -from .channels import ZMQSocketChannel - - -def reqrep(meth, channel='shell'): - def wrapped(self, *args, **kwargs): +def wrapped(meth, channel): + def _(self, *args, **kwargs): reply = kwargs.pop('reply', False) timeout = kwargs.pop('timeout', None) msg_id = meth(self, *args, **kwargs) if not reply: return msg_id + return run_sync(self._async_recv_reply)(msg_id, timeout=timeout, channel=channel) + return _ - return self._recv_reply(msg_id, timeout=timeout, channel=channel) - - if not meth.__doc__: - # python -OO removes docstrings, - # so don't bother building the wrapped docstring - return wrapped - - basedoc, _ = meth.__doc__.split('Returns\n', 1) - parts = [basedoc.strip()] - if 'Parameters' not in basedoc: - parts.append(""" - Parameters - ---------- - """) - parts.append(""" - reply: bool (default: False) - Whether to wait for and return reply - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - Returns - ------- - msg_id: str - The msg_id of the request sent, if reply=False (default) - reply: dict - The reply message for this request, if reply=True - """) - wrapped.__doc__ = '\n'.join(parts) - return wrapped class BlockingKernelClient(KernelClient): """A KernelClient with blocking APIs @@ -65,61 +29,16 @@ class BlockingKernelClient(KernelClient): raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds. 
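As the class docstring above says, the blocking `get_*_msg` helpers raise `queue.Empty` when nothing arrives within `timeout`. A small wrapper showing the intended error handling (assumes `client` is a started `BlockingKernelClient`):

```python
from queue import Empty

def next_iopub_or_none(client, timeout=1.0):
    """Return the next iopub message, or None if none arrives in time."""
    try:
        return client.get_iopub_msg(timeout=timeout)
    except Empty:
        return None
```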
""" - def wait_for_ready(self, timeout=None): - """Waits for a response when a client is blocked + #-------------------------------------------------------------------------- + # Channel proxy methods + #-------------------------------------------------------------------------- - - Sets future time for timeout - - Blocks on shell channel until a message is received - - Exit if the kernel has died - - If client times out before receiving a message from the kernel, send RuntimeError - - Flush the IOPub channel - """ - if timeout is None: - abs_timeout = float('inf') - else: - abs_timeout = time.time() + timeout + get_shell_msg = run_sync(KernelClient._async_get_shell_msg) + get_iopub_msg = run_sync(KernelClient._async_get_iopub_msg) + get_stdin_msg = run_sync(KernelClient._async_get_stdin_msg) + get_control_msg = run_sync(KernelClient._async_get_control_msg) - from ..manager import KernelManager - if not isinstance(self.parent, KernelManager): - # This Client was not created by a KernelManager, - # so wait for kernel to become responsive to heartbeats - # before checking for kernel_info reply - while not self.is_alive(): - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) - time.sleep(0.2) - - # Wait for kernel info reply on shell channel - while True: - self.kernel_info() - try: - msg = self.shell_channel.get_msg(block=True, timeout=1) - except Empty: - pass - else: - if msg['msg_type'] == 'kernel_info_reply': - # Checking that IOPub is connected. If it is not connected, start over. - try: - self.iopub_channel.get_msg(block=True, timeout=0.2) - except Empty: - pass - else: - self._handle_kernel_info_reply(msg) - break - - if not self.is_alive(): - raise RuntimeError('Kernel died before replying to kernel_info') - - # Check if current time is ready check time plus timeout - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) - - # Flush IOPub channel - while True: - try: - msg = self.iopub_channel.get_msg(block=True, timeout=0.2) - except Empty: - break + wait_for_ready = run_sync(KernelClient._async_wait_for_ready) # The classes to use for the various channels shell_channel_class = Type(ZMQSocketChannel) @@ -129,216 +48,19 @@ def wait_for_ready(self, timeout=None): control_channel_class = Type(ZMQSocketChannel) - def _recv_reply(self, msg_id, timeout=None, channel='shell'): - """Receive and return the reply for a given request""" - if timeout is not None: - deadline = monotonic() + timeout - while True: - if timeout is not None: - timeout = max(0, deadline - monotonic()) - try: - if channel == 'control': - reply = self.get_control_msg(timeout=timeout) - else: - reply = self.get_shell_msg(timeout=timeout) - except Empty as e: - raise TimeoutError("Timeout waiting for reply") from e - if reply['parent_header'].get('msg_id') != msg_id: - # not my reply, someone may have forgotten to retrieve theirs - continue - return reply + _recv_reply = run_sync(KernelClient._async_recv_reply) # replies come on the shell channel - execute = reqrep(KernelClient.execute) - history = reqrep(KernelClient.history) - complete = reqrep(KernelClient.complete) - inspect = reqrep(KernelClient.inspect) - kernel_info = reqrep(KernelClient.kernel_info) - comm_info = reqrep(KernelClient.comm_info) - - # replies come on the control channel - shutdown = reqrep(KernelClient.shutdown, channel='control') - - - def _stdin_hook_default(self, msg): - """Handle an input request""" - content = 
msg['content'] - if content.get('password', False): - prompt = getpass - else: - prompt = input - - try: - raw_data = prompt(content["prompt"]) - except EOFError: - # turn EOFError into EOF character - raw_data = '\x04' - except KeyboardInterrupt: - sys.stdout.write('\n') - return - - # only send stdin reply if there *was not* another request - # or execution finished while we were reading. - if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): - self.input(raw_data) - - def _output_hook_default(self, msg): - """Default hook for redisplaying plain-text output""" - msg_type = msg['header']['msg_type'] - content = msg['content'] - if msg_type == 'stream': - stream = getattr(sys, content['name']) - stream.write(content['text']) - elif msg_type in ('display_data', 'execute_result'): - sys.stdout.write(content['data'].get('text/plain', '')) - elif msg_type == 'error': - print('\n'.join(content['traceback']), file=sys.stderr) - - def _output_hook_kernel(self, session, socket, parent_header, msg): - """Output hook when running inside an IPython kernel - - adds rich output support. - """ - msg_type = msg['header']['msg_type'] - if msg_type in ('display_data', 'execute_result', 'error'): - session.send(socket, msg_type, msg['content'], parent=parent_header) - else: - self._output_hook_default(msg) - - def execute_interactive(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True, - timeout=None, output_hook=None, stdin_hook=None, - ): - """Execute code in the kernel interactively - - Output will be redisplayed, and stdin prompts will be relayed as well. - If an IPython kernel is detected, rich output will be displayed. + execute = reqrep(wrapped, KernelClient._execute) + history = reqrep(wrapped, KernelClient._history) + complete = reqrep(wrapped, KernelClient._complete) + inspect = reqrep(wrapped, KernelClient._inspect) + kernel_info = reqrep(wrapped, KernelClient._kernel_info) + comm_info = reqrep(wrapped, KernelClient._comm_info) - You can pass a custom output_hook callable that will be called - with every IOPub message that is produced instead of the default redisplay. + is_alive = run_sync(KernelClient._async_is_alive) + execute_interactive = run_sync(KernelClient._async_execute_interactive) - .. versionadded:: 5.0 - - Parameters - ---------- - code : str - A string of code in the kernel's language. - - silent : bool, optional (default False) - If set, the kernel will execute the code as quietly possible, and - will force store_history to be False. - - store_history : bool, optional (default True) - If set, the kernel will store command history. This is forced - to be False if silent is True. - - user_expressions : dict, optional - A dict mapping names to expressions to be evaluated in the user's - dict. The expression values are returned as strings formatted using - :func:`repr`. - - allow_stdin : bool, optional (default self.allow_stdin) - Flag for whether the kernel can send stdin requests to frontends. - - Some frontends (e.g. the Notebook) do not support stdin requests. - If raw_input is called from code executed from such a frontend, a - StdinNotImplementedError will be raised. - - stop_on_error: bool, optional (default True) - Flag whether to abort the execution queue, if an exception is encountered. - - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - output_hook: callable(msg) - Function to be called with output messages. - If not specified, output will be redisplayed. 
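`run_sync` is imported from `..utils`, but its definition is not part of this diff. The general shape of such a sync-over-async bridge, sketched under that assumption (the real helper in `jupyter_client/utils.py` may differ):

```python
import asyncio

def run_sync(coro_func):
    # Sketch only: wraps a coroutine function so it can be called
    # synchronously by driving a private event loop to completion.
    def wrapper(*args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro_func(*args, **kwargs))
        finally:
            loop.close()
    return wrapper

async def _double(x):
    return 2 * x

print(run_sync(_double)(21))  # -> 42
```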
- - stdin_hook: callable(msg) - Function to be called with stdin_request messages. - If not specified, input/getpass will be called. - - Returns - ------- - reply: dict - The reply message for this request - """ - if not self.iopub_channel.is_alive(): - raise RuntimeError("IOPub channel must be running to receive output") - if allow_stdin is None: - allow_stdin = self.allow_stdin - if allow_stdin and not self.stdin_channel.is_alive(): - raise RuntimeError("stdin channel must be running to allow input") - msg_id = self.execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, - ) - if stdin_hook is None: - stdin_hook = self._stdin_hook_default - if output_hook is None: - # detect IPython kernel - if 'IPython' in sys.modules: - from IPython import get_ipython - ip = get_ipython() - in_kernel = getattr(ip, 'kernel', False) - if in_kernel: - output_hook = partial( - self._output_hook_kernel, - ip.display_pub.session, - ip.display_pub.pub_socket, - ip.display_pub.parent_header, - ) - if output_hook is None: - # default: redisplay plain-text outputs - output_hook = self._output_hook_default - - # set deadline based on timeout - if timeout is not None: - deadline = monotonic() + timeout - else: - timeout_ms = None - - poller = zmq.Poller() - iopub_socket = self.iopub_channel.socket - poller.register(iopub_socket, zmq.POLLIN) - if allow_stdin: - stdin_socket = self.stdin_channel.socket - poller.register(stdin_socket, zmq.POLLIN) - else: - stdin_socket = None - - # wait for output and redisplay it - while True: - if timeout is not None: - timeout = max(0, deadline - monotonic()) - timeout_ms = 1e3 * timeout - events = dict(poller.poll(timeout_ms)) - if not events: - raise TimeoutError("Timeout waiting for output") - if stdin_socket in events: - req = self.stdin_channel.get_msg(timeout=0) - stdin_hook(req) - continue - if iopub_socket not in events: - continue - - msg = self.iopub_channel.get_msg(timeout=0) - - if msg['parent_header'].get('msg_id') != msg_id: - # not from my request - continue - output_hook(msg) - - # stop on idle - if msg['header']['msg_type'] == 'status' and \ - msg['content']['execution_state'] == 'idle': - break - - # output is done, get the reply - if timeout is not None: - timeout = max(0, deadline - monotonic()) - return self._recv_reply(msg_id, timeout=timeout) + # replies come on the control channel + shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control') diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index 90746e202..a94e959fa 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -8,8 +8,11 @@ from threading import Thread, Event import time import asyncio +from queue import Empty +import typing as t import zmq +import zmq.asyncio # import ZMQError in top-level namespace, to avoid ugly attribute-error messages # during garbage collection of threads at exit: from zmq import ZMQError @@ -17,6 +20,7 @@ from jupyter_client import protocol_version_info from .channelsabc import HBChannelABC +from .session import Session #----------------------------------------------------------------------------- # Constants and exceptions @@ -34,24 +38,27 @@ class HBChannel(Thread): this channel, the kernel manager will ensure that it is paused and un-paused as appropriate. """ - context = None session = None socket = None address = None _exiting = False - time_to_dead = 1. - poller = None + time_to_dead: float = 1. 
_running = None _pause = None _beating = None - def __init__(self, context=None, session=None, address=None, loop=None): + def __init__( + self, + context: zmq.asyncio.Context, + session: t.Optional[Session] = None, + address: t.Union[t.Tuple[str, int], str] = '' + ): """Create the heartbeat monitor thread. Parameters ---------- - context : :class:`zmq.Context` + context : :class:`zmq.asyncio.Context` The ZMQ context to use. session : :class:`session.Session` The session to use. @@ -61,16 +68,16 @@ def __init__(self, context=None, session=None, address=None, loop=None): super().__init__() self.daemon = True - self.loop = loop - self.context = context self.session = session if isinstance(address, tuple): if address[1] == 0: message = 'The port number for a channel cannot be 0.' raise InvalidPortNumber(message) - address = "tcp://%s:%i" % address - self.address = address + address_str = "tcp://%s:%i" % address + else: + address_str = address + self.address = address_str # running is False until `.start()` is called self._running = False @@ -81,13 +88,13 @@ def __init__(self, context=None, session=None, address=None, loop=None): @staticmethod @atexit.register - def _notice_exit(): + def _notice_exit() -> None: # Class definitions can be torn down during interpreter shutdown. # We only need to set _exiting flag if this hasn't happened. if HBChannel is not None: HBChannel._exiting = True - def _create_socket(self): + def _create_socket(self) -> None: if self.socket is not None: # close previous socket, before opening a new one self.poller.unregister(self.socket) @@ -98,7 +105,10 @@ def _create_socket(self): self.poller.register(self.socket, zmq.POLLIN) - def _poll(self, start_time): + def _poll( + self, + start_time: float + ) -> t.List[t.Any]: """poll for heartbeat replies until we reach self.time_to_dead. Ignores interrupts, and returns the result of poll(), which @@ -112,7 +122,7 @@ def _poll(self, start_time): events = [] while True: try: - events = self.poller.poll(1000 * until_dead) + events = self.poller.poll(int(1000 * until_dead)) except ZMQError as e: if e.errno == errno.EINTR: # ignore interrupts during heartbeat @@ -131,13 +141,17 @@ def _poll(self, start_time): break return events - def run(self): + def run(self) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._async_run()) + + async def _async_run(self) -> None: """The thread's main activity. 
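`run()` now gives the heartbeat thread its own event loop instead of borrowing one from the caller, which is what lets the channel await the `zmq.asyncio` socket. The pattern in isolation (no zmq, just the thread-owns-loop shape):

```python
import asyncio
from threading import Thread

class LoopThread(Thread):
    # each thread creates and drives a private loop, as HBChannel.run does
    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._main())

    async def _main(self):
        await asyncio.sleep(0.05)
        print('beat')

t = LoopThread()
t.start()
t.join()
```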
Call start() instead.""" - if self.loop is not None: - asyncio.set_event_loop(self.loop) self._create_socket() self._running = True self._beating = True + assert self.socket is not None while self._running: if self._pause: @@ -148,13 +162,13 @@ def run(self): since_last_heartbeat = 0.0 # no need to catch EFSM here, because the previous event was # either a recv or connect, which cannot be followed by EFSM - self.socket.send(b'ping') + await self.socket.send(b'ping') request_time = time.time() ready = self._poll(request_time) if ready: self._beating = True # the poll above guarantees we have something to recv - self.socket.recv() + await self.socket.recv() # sleep the remainder of the cycle remainder = self.time_to_dead - (time.time() - request_time) if remainder > 0: @@ -169,29 +183,29 @@ def run(self): self._create_socket() continue - def pause(self): + def pause(self) -> None: """Pause the heartbeat.""" self._pause = True - def unpause(self): + def unpause(self) -> None: """Unpause the heartbeat.""" self._pause = False - def is_beating(self): + def is_beating(self) -> bool: """Is the heartbeat running and responsive (and not paused).""" if self.is_alive() and not self._pause and self._beating: return True else: return False - def stop(self): + def stop(self) -> None: """Stop the channel's event loop and join its thread.""" self._running = False self._exit.set() self.join() self.close() - def close(self): + def close(self) -> None: if self.socket is not None: try: self.socket.close(linger=0) @@ -199,7 +213,10 @@ def close(self): pass self.socket = None - def call_handlers(self, since_last_heartbeat): + def call_handlers( + self, + since_last_heartbeat: float + ) -> None: """This method is called in the ioloop thread when a message arrives. Subclasses should override this method to handle incoming messages. @@ -211,3 +228,90 @@ def call_handlers(self, since_last_heartbeat): HBChannelABC.register(HBChannel) + + +class ZMQSocketChannel(object): + """A ZMQ socket in an async API""" + + def __init__( + self, + socket: zmq.sugar.socket.Socket, + session: Session, + loop: t.Any = None + ) -> None: + """Create a channel. + + Parameters + ---------- + socket : :class:`zmq.asyncio.Socket` + The ZMQ socket to use. + session : :class:`session.Session` + The session to use. + loop + Unused here, for other implementations + """ + super().__init__() + + self.socket: t.Optional[zmq.sugar.socket.Socket] = socket + self.session = session + + async def _recv(self, **kwargs) -> t.Dict[str, t.Any]: + assert self.socket is not None + msg = await self.socket.recv_multipart(**kwargs) + ident, smsg = self.session.feed_identities(msg) + return self.session.deserialize(smsg) + + async def get_msg( + self, + timeout: t.Optional[float] = None + ) -> t.Dict[str, t.Any]: + """ Gets a message if there is one that is ready. """ + if timeout is not None: + timeout *= 1000 # seconds to ms + assert self.socket is not None + ready = await self.socket.poll(timeout) + + if ready: + res = await self._recv() + return res + else: + raise Empty + + async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]: + """ Get all messages that are currently ready. """ + msgs = [] + while True: + try: + msgs.append(await self.get_msg()) + except Empty: + break + return msgs + + async def msg_ready(self) -> bool: + """ Is there a message that has been received? 
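With the two per-flavor channel modules deleted, this single async `ZMQSocketChannel` in `channels.py` serves both clients. A typical consumer loop over its coroutine API, the same shape `_async_execute_interactive` uses later in this diff (assumes a started async client):

```python
from queue import Empty

async def iopub_until_idle(client, msg_id):
    """Collect iopub output for one request until the kernel goes idle."""
    out = []
    while True:
        try:
            msg = await client.iopub_channel.get_msg(timeout=1)
        except Empty:
            continue
        if msg['parent_header'].get('msg_id') != msg_id:
            continue  # output from someone else's request
        out.append(msg)
        if (msg['header']['msg_type'] == 'status'
                and msg['content']['execution_state'] == 'idle'):
            return out
```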
""" + assert self.socket is not None + return bool(await self.socket.poll(timeout=0)) + + def close(self) -> None: + if self.socket is not None: + try: + self.socket.close(linger=0) + except Exception: + pass + self.socket = None + stop = close + + def is_alive(self) -> bool: + return (self.socket is not None) + + def send( + self, + msg: t.Dict[str, t.Any] + ) -> None: + """Pass a message to the ZMQ socket to send + """ + assert self.socket is not None + self.session.send(self.socket, msg) + + def start(self) -> None: + pass diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 760ac5266..b4ba3f004 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -3,33 +3,82 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import sys +import asyncio +import time +from functools import partial +from getpass import getpass +from queue import Empty +import socket +import typing as t + from jupyter_client.channels import major_protocol_version import zmq +import zmq.asyncio -from traitlets import ( - Any, Instance, Type, +from traitlets import ( # type: ignore + Any, Instance, Type, default ) from .channelsabc import (ChannelABC, HBChannelABC) from .clientabc import KernelClientABC from .connect import ConnectionFileMixin +from .session import Session +from .utils import ensure_async # some utilities to validate message structure, these might get moved elsewhere # if they prove to have more generic utility -def validate_string_dict(dct): +def validate_string_dict( + dct: t.Dict[str, str] +) -> None: """Validate that the input is a dict with string keys and values. Raises ValueError if not.""" - for k,v in dct.items(): + for k, v in dct.items(): if not isinstance(k, str): raise ValueError('key %r in dict must be a string' % k) if not isinstance(v, str): raise ValueError('value %r in dict must be a string' % v) +def reqrep( + wrapped: t.Callable, + meth: t.Callable, + channel: str = 'shell' +) -> t.Callable: + wrapped = wrapped(meth, channel) + if not meth.__doc__: + # python -OO removes docstrings, + # so don't bother building the wrapped docstring + return wrapped + + basedoc, _ = meth.__doc__.split('Returns\n', 1) + parts = [basedoc.strip()] + if 'Parameters' not in basedoc: + parts.append(""" + Parameters + ---------- + """) + parts.append(""" + reply: bool (default: False) + Whether to wait for and return reply + timeout: float or None (default: None) + Timeout to use when waiting for a reply + + Returns + ------- + msg_id: str + The msg_id of the request sent, if reply=False (default) + reply: dict + The reply message for this request, if reply=True + """) + wrapped.__doc__ = '\n'.join(parts) + return wrapped + + class KernelClient(ConnectionFileMixin): """Communicates with a single kernel on any host via zmq channels. @@ -48,9 +97,9 @@ class KernelClient(ConnectionFileMixin): """ # The PyZMQ Context to use for communication with the kernel. 
- context = Instance(zmq.Context) - def _context_default(self): - return zmq.Context() + context = Instance(zmq.asyncio.Context) + def _context_default(self) -> zmq.asyncio.Context: + return zmq.asyncio.Context() # The classes to use for the various channels shell_channel_class = Type(ChannelABC) @@ -67,33 +116,181 @@ def _context_default(self): _control_channel = Any() # flag for whether execute requests should be allowed to call raw_input: - allow_stdin = True + allow_stdin: bool = True #-------------------------------------------------------------------------- # Channel proxy methods #-------------------------------------------------------------------------- - def get_shell_msg(self, *args, **kwargs): + async def _async_get_shell_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the shell channel""" - return self.shell_channel.get_msg(*args, **kwargs) + return await self.shell_channel.get_msg(*args, **kwargs) - def get_iopub_msg(self, *args, **kwargs): + async def _async_get_iopub_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the iopub channel""" - return self.iopub_channel.get_msg(*args, **kwargs) + return await self.iopub_channel.get_msg(*args, **kwargs) - def get_stdin_msg(self, *args, **kwargs): + async def _async_get_stdin_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the stdin channel""" - return self.stdin_channel.get_msg(*args, **kwargs) + return await self.stdin_channel.get_msg(*args, **kwargs) - def get_control_msg(self, *args, **kwargs): + async def _async_get_control_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the control channel""" - return self.control_channel.get_msg(*args, **kwargs) + return await self.control_channel.get_msg(*args, **kwargs) + + async def _async_wait_for_ready( + self, + timeout: t.Optional[float] = None + ) -> None: + """Waits for a response when a client is blocked + + - Sets future time for timeout + - Blocks on shell channel until a message is received + - Exit if the kernel has died + - If client times out before receiving a message from the kernel, send RuntimeError + - Flush the IOPub channel + """ + if timeout is None: + timeout = float('inf') + abs_timeout = time.time() + timeout + + from .manager import KernelManager + if not isinstance(self.parent, KernelManager): + # This Client was not created by a KernelManager, + # so wait for kernel to become responsive to heartbeats + # before checking for kernel_info reply + while not await ensure_async(self.is_alive()): + if time.time() > abs_timeout: + raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) + await asyncio.sleep(0.2) + + # Wait for kernel info reply on shell channel + while True: + self._kernel_info() + try: + msg = await self.shell_channel.get_msg(timeout=1) + except Empty: + pass + else: + if msg['msg_type'] == 'kernel_info_reply': + # Checking that IOPub is connected. If it is not connected, start over. 
+ try: + await self.iopub_channel.get_msg(timeout=0.2) + except Empty: + pass + else: + self._handle_kernel_info_reply(msg) + break + + if not await ensure_async(self.is_alive()): + raise RuntimeError('Kernel died before replying to kernel_info') + + # Check if current time is ready check time plus timeout + if time.time() > abs_timeout: + raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) + + # Flush IOPub channel + while True: + try: + msg = await self.iopub_channel.get_msg(timeout=0.2) + except Empty: + break + + async def _async_recv_reply( + self, + msg_id: str, + timeout: t.Optional[float] = None, + channel: str = 'shell' + ) -> t.Dict[str, t.Any]: + """Receive and return the reply for a given request""" + if timeout is not None: + deadline = time.monotonic() + timeout + while True: + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) + try: + if channel == 'control': + reply = await self._async_get_control_msg(timeout=timeout) + else: + reply = await self._async_get_shell_msg(timeout=timeout) + except Empty as e: + raise TimeoutError("Timeout waiting for reply") from e + if reply['parent_header'].get('msg_id') != msg_id: + # not my reply, someone may have forgotten to retrieve theirs + continue + return reply + + + def _stdin_hook_default( + self, + msg: t.Dict[str, t.Any] + ) -> None: + """Handle an input request""" + content = msg['content'] + if content.get('password', False): + prompt = getpass + else: + prompt = input # type: ignore + + try: + raw_data = prompt(content["prompt"]) + except EOFError: + # turn EOFError into EOF character + raw_data = '\x04' + except KeyboardInterrupt: + sys.stdout.write('\n') + return + + # only send stdin reply if there *was not* another request + # or execution finished while we were reading. + if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): + self.input(raw_data) + + def _output_hook_default( + self, + msg: t.Dict[str, t.Any] + ) -> None: + """Default hook for redisplaying plain-text output""" + msg_type = msg['header']['msg_type'] + content = msg['content'] + if msg_type == 'stream': + stream = getattr(sys, content['name']) + stream.write(content['text']) + elif msg_type in ('display_data', 'execute_result'): + sys.stdout.write(content['data'].get('text/plain', '')) + elif msg_type == 'error': + print('\n'.join(content['traceback']), file=sys.stderr) + + def _output_hook_kernel( + self, + session: Session, + socket: zmq.sugar.socket.Socket, + parent_header, + msg: t.Dict[str, t.Any] + ) -> None: + """Output hook when running inside an IPython kernel + + adds rich output support. + """ + msg_type = msg['header']['msg_type'] + if msg_type in ('display_data', 'execute_result', 'error'): + session.send(socket, msg_type, msg['content'], parent=parent_header) + else: + self._output_hook_default(msg) + #-------------------------------------------------------------------------- # Channel management methods #-------------------------------------------------------------------------- - def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): + def start_channels( + self, + shell: bool = True, + iopub: bool = True, + stdin: bool = True, + hb: bool = True, + control: bool = True + ) -> None: """Starts the channels for this kernel. 
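Putting the pieces together, an end-to-end async session against these APIs could look like the following. It assumes the `AsyncKernelManager` from `jupyter_client/manager.py`, which the new CI step type-checks but which is not shown in this diff:

```python
import asyncio
from jupyter_client import AsyncKernelManager

async def main():
    km = AsyncKernelManager()
    await km.start_kernel()
    client = km.client()      # an AsyncKernelClient
    client.start_channels()
    try:
        await client.wait_for_ready(timeout=60)
        reply = await client.execute('1 + 1', reply=True)
        print(reply['content']['status'])  # 'ok'
    finally:
        client.stop_channels()
        await km.shutdown_kernel()

asyncio.run(main())
```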
This will create the channels if they do not exist and then start @@ -116,7 +313,7 @@ def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=Tr if control: self.control_channel.start() - def stop_channels(self): + def stop_channels(self) -> None: """Stops all the running channels for this kernel. This stops their event loops and joins their threads. @@ -133,7 +330,7 @@ def stop_channels(self): self.control_channel.stop() @property - def channels_running(self): + def channels_running(self) -> bool: """Are any of the channels created and running?""" return (self.shell_channel.is_alive() or self.iopub_channel.is_alive() or self.stdin_channel.is_alive() or self.hb_channel.is_alive() or @@ -142,7 +339,7 @@ def channels_running(self): ioloop = None # Overridden in subclasses that use pyzmq event loop @property - def shell_channel(self): + def shell_channel(self) -> t.Any: """Get the shell channel object for this kernel.""" if self._shell_channel is None: url = self._make_url('shell') @@ -154,7 +351,7 @@ def shell_channel(self): return self._shell_channel @property - def iopub_channel(self): + def iopub_channel(self) -> t.Any: """Get the iopub channel object for this kernel.""" if self._iopub_channel is None: url = self._make_url('iopub') @@ -166,7 +363,7 @@ def iopub_channel(self): return self._iopub_channel @property - def stdin_channel(self): + def stdin_channel(self) -> t.Any: """Get the stdin channel object for this kernel.""" if self._stdin_channel is None: url = self._make_url('stdin') @@ -178,7 +375,7 @@ def stdin_channel(self): return self._stdin_channel @property - def hb_channel(self): + def hb_channel(self) -> t.Any: """Get the hb channel object for this kernel.""" if self._hb_channel is None: url = self._make_url('hb') @@ -189,7 +386,7 @@ def hb_channel(self): return self._hb_channel @property - def control_channel(self): + def control_channel(self) -> t.Any: """Get the control channel object for this kernel.""" if self._control_channel is None: url = self._make_url('control') @@ -200,13 +397,13 @@ def control_channel(self): ) return self._control_channel - def is_alive(self): + async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" - from .manager import KernelManager + from .manager import KernelManager, AsyncKernelManager if isinstance(self.parent, KernelManager): # This KernelClient was created by a KernelManager, # we can ask the parent KernelManager: - return self.parent.is_alive() + return await ensure_async(self.parent.is_alive()) if self._hb_channel is not None: # We don't have access to the KernelManager, # so we use the heartbeat. @@ -217,9 +414,162 @@ def is_alive(self): return True + async def _async_execute_interactive( + self, + code: str, + silent: bool = False, + store_history: bool = True, + user_expressions: t.Optional[t.Dict[str, t.Any]] = None, + allow_stdin: t.Optional[bool] = None, + stop_on_error: bool = True, + timeout: t.Optional[float] = None, + output_hook: t.Optional[t.Callable] = None, + stdin_hook: t.Optional[t.Callable] =None, + ) -> t.Dict[str, t.Any]: + """Execute code in the kernel interactively + + Output will be redisplayed, and stdin prompts will be relayed as well. + If an IPython kernel is detected, rich output will be displayed. + + You can pass a custom output_hook callable that will be called + with every IOPub message that is produced instead of the default redisplay. + + .. versionadded:: 5.0 + + Parameters + ---------- + code : str + A string of code in the kernel's language. 
+ + silent : bool, optional (default False) + If set, the kernel will execute the code as quietly as possible, and + will force store_history to be False. + + store_history : bool, optional (default True) + If set, the kernel will store command history. This is forced + to be False if silent is True. + + user_expressions : dict, optional + A dict mapping names to expressions to be evaluated in the user's + dict. The expression values are returned as strings formatted using + :func:`repr`. + + allow_stdin : bool, optional (default self.allow_stdin) + Flag for whether the kernel can send stdin requests to frontends. + + Some frontends (e.g. the Notebook) do not support stdin requests. + If raw_input is called from code executed from such a frontend, a + StdinNotImplementedError will be raised. + + stop_on_error: bool, optional (default True) + Whether to abort the execution queue if an exception is encountered. + + timeout: float or None (default: None) + Timeout to use when waiting for a reply + + output_hook: callable(msg) + Function to be called with output messages. + If not specified, output will be redisplayed. + + stdin_hook: callable(msg) + Function to be called with stdin_request messages. + If not specified, input/getpass will be called. + + Returns + ------- + reply: dict + The reply message for this request + """ + if not self.iopub_channel.is_alive(): + raise RuntimeError("IOPub channel must be running to receive output") + if allow_stdin is None: + allow_stdin = self.allow_stdin + if allow_stdin and not self.stdin_channel.is_alive(): + raise RuntimeError("stdin channel must be running to allow input") + msg_id = self._execute(code, + silent=silent, + store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error, + ) + if stdin_hook is None: + stdin_hook = self._stdin_hook_default + if output_hook is None: + # detect IPython kernel + if 'IPython' in sys.modules: + from IPython import get_ipython # type: ignore + ip = get_ipython() + in_kernel = getattr(ip, 'kernel', False) + if in_kernel: + output_hook = partial( + self._output_hook_kernel, + ip.display_pub.session, + ip.display_pub.pub_socket, + ip.display_pub.parent_header, + ) + if output_hook is None: + # default: redisplay plain-text outputs + output_hook = self._output_hook_default + + # set deadline based on timeout + if timeout is not None: + deadline = time.monotonic() + timeout + else: + timeout_ms = None + + poller = zmq.Poller() + iopub_socket = self.iopub_channel.socket + poller.register(iopub_socket, zmq.POLLIN) + if allow_stdin: + stdin_socket = self.stdin_channel.socket + poller.register(stdin_socket, zmq.POLLIN) + else: + stdin_socket = None + + # wait for output and redisplay it + while True: + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) + timeout_ms = int(1000 * timeout) + events = dict(poller.poll(timeout_ms)) + if not events: + raise TimeoutError("Timeout waiting for output") + if stdin_socket in events: + req = await self.stdin_channel.get_msg(timeout=0) + stdin_hook(req) + continue + if iopub_socket not in events: + continue + + msg = await self.iopub_channel.get_msg(timeout=0) + + if msg['parent_header'].get('msg_id') != msg_id: + # not from my request + continue + output_hook(msg) + + # stop on idle + if msg['header']['msg_type'] == 'status' and \ + msg['content']['execution_state'] == 'idle': + break + + # output is done, get the reply + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) +
return await self._async_recv_reply(msg_id, timeout=timeout) + + # Methods to send specific messages on channels - def execute(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True): + def _execute( + self, + code: str, + silent: bool = False, + store_history: bool = True, + user_expressions: t.Optional[t.Dict[str, t.Any]] = None, + allow_stdin: t.Optional[bool] = None, + stop_on_error: bool = True + ) -> str: """Execute code in the kernel. Parameters @@ -275,7 +625,11 @@ def execute(self, code, silent=False, store_history=True, self.shell_channel.send(msg) return msg['header']['msg_id'] - def complete(self, code, cursor_pos=None): + def _complete( + self, + code: str, + cursor_pos: t.Optional[int] = None + ) -> str: """Tab complete text in the kernel's namespace. Parameters @@ -298,7 +652,12 @@ def complete(self, code, cursor_pos=None): self.shell_channel.send(msg) return msg['header']['msg_id'] - def inspect(self, code, cursor_pos=None, detail_level=0): + def _inspect( + self, + code: str, + cursor_pos: t.Optional[int] = None, + detail_level: int = 0 + ) -> str: """Get metadata information about an object in the kernel's namespace. It is up to the kernel to determine the appropriate object to inspect. @@ -327,7 +686,13 @@ def inspect(self, code, cursor_pos=None, detail_level=0): self.shell_channel.send(msg) return msg['header']['msg_id'] - def history(self, raw=True, output=False, hist_access_type='range', **kwargs): + def _history( + self, + raw: bool = True, + output: bool = False, + hist_access_type: str = 'range', + **kwargs + ) -> str: """Get entries from the kernel's history list. Parameters @@ -368,7 +733,7 @@ def history(self, raw=True, output=False, hist_access_type='range', **kwargs): self.shell_channel.send(msg) return msg['header']['msg_id'] - def kernel_info(self): + def _kernel_info(self) -> str: """Request kernel info Returns @@ -379,7 +744,10 @@ def kernel_info(self): self.shell_channel.send(msg) return msg['header']['msg_id'] - def comm_info(self, target_name=None): + def _comm_info( + self, + target_name: t.Optional[str] = None + ) -> str: """Request comm info Returns @@ -394,7 +762,10 @@ def comm_info(self, target_name=None): self.shell_channel.send(msg) return msg['header']['msg_id'] - def _handle_kernel_info_reply(self, msg): + def _handle_kernel_info_reply( + self, + msg: t.Dict[str, t.Any] + ) -> None: """handle kernel info reply sets protocol adaptation version. This might @@ -404,13 +775,19 @@ def _handle_kernel_info_reply(self, msg): if adapt_version != major_protocol_version: self.session.adapt_version = adapt_version - def is_complete(self, code): + def is_complete( + self, + code: str + ) -> str: """Ask the kernel whether some code is complete and ready to execute.""" msg = self.session.msg('is_complete_request', {'code': code}) self.shell_channel.send(msg) return msg['header']['msg_id'] - def input(self, string): + def input( + self, + string: str + ) -> None: """Send a string of raw input to the kernel. This should only be called in response to the kernel sending an @@ -420,7 +797,10 @@ def input(self, string): msg = self.session.msg('input_reply', content) self.stdin_channel.send(msg) - def shutdown(self, restart=False): + def _shutdown( + self, + restart: bool = False + ) -> str: """Request an immediate kernel shutdown on the control channel. 
Upon receipt of the (empty) reply, client code can safely assume that diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index e31f9ec1d..2f8e70352 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -17,23 +17,33 @@ import warnings from getpass import getpass from contextlib import contextmanager +from typing import Union, Optional, List, Tuple, Dict, Any, cast import zmq -from traitlets.config import LoggingConfigurable +from traitlets.config import LoggingConfigurable # type: ignore from .localinterfaces import localhost -from traitlets import ( +from traitlets import ( # type: ignore Bool, Integer, Unicode, CaselessStrEnum, Instance, Type, observe ) -from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write +from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write # type: ignore from .utils import _filefind -def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0, - control_port=0, ip='', key=b'', transport='tcp', - signature_scheme='hmac-sha256', kernel_name='' - ): +def write_connection_file( + fname: Optional[str] = None, + shell_port: int = 0, + iopub_port: int = 0, + stdin_port: int = 0, + hb_port: int = 0, + control_port: int = 0, + ip: str = '', + key: bytes = b'', + transport: str = 'tcp', + signature_scheme: str = 'hmac-sha256', + kernel_name: str = '' +) -> Tuple[str, Dict[str, Union[int, str]]]: """Generates a JSON config file, including the selection of random ports. Parameters @@ -83,7 +93,8 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, # Find open ports as necessary. - ports = [] + ports: List[int] = [] + sockets: List[socket.socket] = [] ports_needed = int(shell_port <= 0) + \ int(iopub_port <= 0) + \ int(stdin_port <= 0) + \ @@ -95,11 +106,11 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, # struct.pack('ii', (0,0)) is 8 null bytes sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) sock.bind((ip, 0)) - ports.append(sock) - for i, sock in enumerate(ports): + sockets.append(sock) + for sock in sockets: port = sock.getsockname()[1] sock.close() - ports[i] = port + ports.append(port) else: N = 1 for i in range(ports_needed): @@ -118,7 +129,7 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, if hb_port <= 0: hb_port = ports.pop(0) - cfg = dict( shell_port=shell_port, + cfg: Dict[str, Union[int, str]] = dict( shell_port=shell_port, iopub_port=iopub_port, stdin_port=stdin_port, control_port=control_port, @@ -165,7 +176,11 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, return fname, cfg -def find_connection_file(filename='kernel-*.json', path=None, profile=None): +def find_connection_file( + filename: str ='kernel-*.json', + path: Optional[Union[str, List[str]]] = None, + profile: Optional[str] = None +) -> str: """find a connection file, and return its absolute path. 
The current working directory and optional search path @@ -222,7 +237,11 @@ def find_connection_file(filename='kernel-*.json', path=None, profile=None): return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1] -def tunnel_to_kernel(connection_info, sshserver, sshkey=None): +def tunnel_to_kernel( + connection_info: Union[str, Dict[str, Any]], + sshserver: str, + sshkey: Optional[str] = None +) -> Tuple[Any, ...]: """tunnel connections to a kernel via ssh This will open five SSH tunnels from localhost on this machine to the @@ -254,7 +273,7 @@ def tunnel_to_kernel(connection_info, sshserver, sshkey=None): with open(connection_info) as f: connection_info = json.loads(f.read()) - cf = connection_info + cf = cast(Dict[str, Any], connection_info) lports = tunnel.select_random_ports(5) rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port'], cf['control_port'] @@ -262,11 +281,11 @@ def tunnel_to_kernel(connection_info, sshserver, sshkey=None): remote_ip = cf['ip'] if tunnel.try_passwordless_ssh(sshserver, sshkey): - password=False + password: Union[bool, str] = False else: password = getpass("SSH Password for %s: " % sshserver) - for lp,rp in zip(lports, rports): + for lp, rp in zip(lports, rports): tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password) return tuple(lports) @@ -341,10 +360,10 @@ def _ip_changed(self, change): help="set the control (ROUTER) port [default: random]") # names of the ports with random assignment - _random_port_names = None + _random_port_names: Optional[List[str]] = None @property - def ports(self): + def ports(self) -> List[int]: return [ getattr(self, name) for name in port_names ] # The Session to use for communication with the kernel. @@ -357,7 +376,10 @@ def _session_default(self): # Connection and ipc file management #-------------------------------------------------------------------------- - def get_connection_info(self, session=False): + def get_connection_info( + self, + session: bool =False + ) -> Dict[str, Any]: """Return the connection info as a dict Parameters @@ -403,7 +425,7 @@ def blocking_client(self): bc.session.key = self.session.key return bc - def cleanup_connection_file(self): + def cleanup_connection_file(self) -> None: """Cleanup connection file *if we wrote it* Will not raise if the connection file was already removed somehow. @@ -416,7 +438,7 @@ def cleanup_connection_file(self): except (IOError, OSError, AttributeError): pass - def cleanup_ipc_files(self): + def cleanup_ipc_files(self) -> None: """Cleanup ipc files if we wrote them.""" if self.transport != 'ipc': return @@ -427,7 +449,7 @@ def cleanup_ipc_files(self): except (IOError, OSError): pass - def _record_random_port_names(self): + def _record_random_port_names(self) -> None: """Records which of the ports are randomly assigned. Records on first invocation, if the transport is tcp. @@ -443,7 +465,7 @@ def _record_random_port_names(self): if getattr(self, name) <= 0: self._random_port_names.append(name) - def cleanup_random_ports(self): + def cleanup_random_ports(self) -> None: """Forgets randomly assigned port numbers and cleans up the connection file. Does nothing if no port numbers have been randomly assigned. 
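The connect.py hunks above pin down the public signatures of write_connection_file, find_connection_file, and tunnel_to_kernel. The following is a minimal usage sketch of the first two, assuming only that jupyter_client is importable; the temporary directory and the kernel-12345.json file name are illustrative, not taken from this diff:

import json
import os
import tempfile

from jupyter_client.connect import find_connection_file, write_connection_file

tmpdir = tempfile.mkdtemp()
requested = os.path.join(tmpdir, 'kernel-12345.json')

# Ports left at their 0 defaults are replaced with random free ports;
# the chosen values come back in the cfg dict along with the file name.
fname, cfg = write_connection_file(fname=requested, ip='127.0.0.1', key=b'secret')
assert fname == requested and cfg['shell_port'] > 0

# find_connection_file globs its search path and returns the most
# recently accessed match as an absolute path.
found = find_connection_file('kernel-12345.json', path=[tmpdir])
with open(found) as f:
    info = json.load(f)
assert info['transport'] == 'tcp' and info['signature_scheme'] == 'hmac-sha256'

The annotated return type, Tuple[str, Dict[str, Union[int, str]]], matches what the sketch relies on: the resolved path plus a JSON-serializable config whose port entries are ints.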
@@ -458,7 +480,7 @@ def cleanup_random_ports(self): self.cleanup_connection_file() - def write_connection_file(self): + def write_connection_file(self) -> None: """Write connection info to JSON dict in self.connection_file.""" if self._connection_file_written and os.path.exists(self.connection_file): return @@ -478,7 +500,10 @@ def write_connection_file(self): self._connection_file_written = True - def load_connection_file(self, connection_file=None): + def load_connection_file( + self, + connection_file: Optional[str] = None + ) -> None: """Load connection info from JSON dict in self.connection_file. Parameters @@ -494,7 +519,10 @@ def load_connection_file(self, connection_file=None): info = json.load(f) self.load_connection_info(info) - def load_connection_info(self, info): + def load_connection_info( + self, + info: Dict[str, int] + ) -> None: """Load connection info from a dict containing connection info. Typically this data comes from a connection file @@ -529,7 +557,10 @@ def load_connection_info(self, info): # Creating connected sockets #-------------------------------------------------------------------------- - def _make_url(self, channel): + def _make_url( + self, + channel: str + ) -> str: """Make a ZeroMQ URL for a given channel.""" transport = self.transport ip = self.ip @@ -540,7 +571,11 @@ def _make_url(self, channel): else: return "%s://%s-%s" % (transport, ip, port) - def _create_connected_socket(self, channel, identity=None): + def _create_connected_socket( + self, + channel: str, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """Create a zmq Socket and connect it to the kernel.""" url = self._make_url(channel) socket_type = channel_socket_types[channel] @@ -553,25 +588,40 @@ def _create_connected_socket(self, channel, identity=None): sock.connect(url) return sock - def connect_iopub(self, identity=None): + def connect_iopub( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the IOPub channel""" sock = self._create_connected_socket('iopub', identity=identity) sock.setsockopt(zmq.SUBSCRIBE, b'') return sock - def connect_shell(self, identity=None): + def connect_shell( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Shell channel""" return self._create_connected_socket('shell', identity=identity) - def connect_stdin(self, identity=None): + def connect_stdin( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the StdIn channel""" return self._create_connected_socket('stdin', identity=identity) - def connect_hb(self, identity=None): + def connect_hb( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Heartbeat channel""" return self._create_connected_socket('hb', identity=identity) - def connect_control(self, identity=None): + def connect_control( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Control channel""" return self._create_connected_socket('control', identity=identity) diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index 42ce2fb77..e491dcc24 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -13,14 +13,15 @@ import sys import uuid import warnings +from typing import cast -from traitlets.config.application import boolean_flag -from traitlets import ( +from traitlets.config.application 
import boolean_flag # type: ignore +from traitlets import ( # type: ignore Dict, List, Unicode, CUnicode, CBool, Any, Type ) -from jupyter_core.application import base_flags, base_aliases +from jupyter_core.application import base_flags, base_aliases # type: ignore from .blocking import BlockingKernelClient from .restarter import KernelRestarter @@ -93,19 +94,19 @@ class JupyterConsoleApp(ConnectionFileMixin): description = """ The Jupyter Console Mixin. - + This class contains the common portions of console client (QtConsole, ZMQ-based terminal console, etc). It is not a full console, in that launched terminal subprocesses will not be able to accept input. - + The Console using this mixing supports various extra features beyond the single-process Terminal IPython shell, such as connecting to existing kernel, via: - + jupyter console --existing - + as well as tunnel via SSH - + """ classes = classes @@ -121,13 +122,13 @@ class JupyterConsoleApp(ConnectionFileMixin): kernel_argv = List(Unicode()) # connection info: - + sshserver = Unicode('', config=True, help="""The SSH server to use to connect to the kernel.""") sshkey = Unicode('', config=True, help="""Path to the ssh key to use for logging in to the ssh server.""") - - def _connection_file_default(self): + + def _connection_file_default(self) -> str: return 'kernel-%i.json' % os.getpid() existing = CUnicode('', config=True, @@ -141,26 +142,26 @@ def _connection_file_default(self): Set to display confirmation dialog on exit. You can always use 'exit' or 'quit', to force a direct exit without any confirmation.""", ) - - def build_kernel_argv(self, argv=None): + + def build_kernel_argv(self, argv=None) -> None: """build argv to be passed to kernel subprocess - + Override in subclasses if any args should be passed to the kernel """ self.kernel_argv = self.extra_args - - def init_connection_file(self): + + def init_connection_file(self) -> None: """find the connection file, and load the info if found. - + The current working directory and the current profile's security directory will be searched for the file if it is not given by absolute path. - + When attempting to connect to an existing kernel and the `--existing` argument does not match an existing file, it will be interpreted as a fileglob, and the matching file in the current profile's security dir with the latest access time will be used. - + After this method is called, self.connection_file contains the *full path* to the connection file, never just its name. """ @@ -192,7 +193,7 @@ def init_connection_file(self): except IOError: self.log.debug("Connection File not found: %s", self.connection_file) return - + # should load_connection_file only be used for existing? 
# as it is now, this allows reusing ports if an existing # file is requested @@ -201,25 +202,25 @@ def init_connection_file(self): except Exception: self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True) self.exit(1) - - def init_ssh(self): + + def init_ssh(self) -> None: """set up ssh tunnels, if needed.""" if not self.existing or (not self.sshserver and not self.sshkey): return self.load_connection_file() - + transport = self.transport ip = self.ip - + if transport != 'tcp': self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport) sys.exit(-1) - + if self.sshkey and not self.sshserver: # specifying just the key implies that we are connecting directly self.sshserver = ip ip = localhost() - + # build connection dict for tunnels: info = dict(ip=ip, shell_port=self.shell_port, @@ -228,9 +229,9 @@ def init_ssh(self): hb_port=self.hb_port, control_port=self.control_port ) - + self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver)) - + # tunnels return a new set of ports, which will be on localhost: self.ip = localhost() try: @@ -239,17 +240,17 @@ def init_ssh(self): # even catch KeyboardInterrupt self.log.error("Could not setup tunnels", exc_info=True) self.exit(1) - + self.shell_port, self.iopub_port, self.stdin_port, self.hb_port, self.control_port = newports - + cf = self.connection_file root, ext = os.path.splitext(cf) self.connection_file = root + '-ssh' + ext self.write_connection_file() # write the new connection file self.log.info("To connect another client via this tunnel, use:") self.log.info("--existing %s" % os.path.basename(self.connection_file)) - - def _new_connection_file(self): + + def _new_connection_file(self) -> str: cf = '' while not cf: # we don't need a 128b id to distinguish kernels, use more readable @@ -262,7 +263,7 @@ def _new_connection_file(self): cf = cf if not os.path.exists(cf) else '' return cf - def init_kernel_manager(self): + def init_kernel_manager(self) -> None: # Don't let Qt or ZMQ swallow KeyboardInterupts. 
if self.existing: self.kernel_manager = None @@ -289,6 +290,7 @@ def init_kernel_manager(self): self.log.critical("Could not find kernel %s", self.kernel_name) self.exit(1) + self.kernel_manager = cast(KernelManager, self.kernel_manager) self.kernel_manager.client_factory = self.kernel_client_class kwargs = {} kwargs['extra_arguments'] = self.kernel_argv @@ -310,7 +312,7 @@ def init_kernel_manager(self): atexit.register(self.kernel_manager.cleanup_connection_file) - def init_kernel_client(self): + def init_kernel_client(self) -> None: if self.kernel_manager is not None: self.kernel_client = self.kernel_manager.client() else: @@ -331,7 +333,7 @@ def init_kernel_client(self): - def initialize(self, argv=None): + def initialize(self, argv=None) -> None: """ Classes which mix this class in should call: JupyterConsoleApp.initialize(self,argv) diff --git a/jupyter_client/jsonutil.py b/jupyter_client/jsonutil.py index d3a472fee..667e33f1f 100644 --- a/jupyter_client/jsonutil.py +++ b/jupyter_client/jsonutil.py @@ -6,6 +6,7 @@ from datetime import datetime import re import warnings +from typing import Optional, Union from dateutil.parser import parse as _dateutil_parse from dateutil.tz import tzlocal @@ -28,7 +29,7 @@ # Classes and functions #----------------------------------------------------------------------------- -def _ensure_tzinfo(dt): +def _ensure_tzinfo(dt: datetime) -> datetime: """Ensure a datetime object has tzinfo If no tzinfo is present, add tzlocal @@ -41,7 +42,7 @@ def _ensure_tzinfo(dt): dt = dt.replace(tzinfo=tzlocal()) return dt -def parse_date(s): +def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]: """parse an ISO8601 date string If it is None or not a valid ISO8601 timestamp, diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index 33607049c..b95afb0b0 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -2,9 +2,9 @@ import signal import uuid -from jupyter_core.application import JupyterApp, base_flags +from jupyter_core.application import JupyterApp, base_flags # type: ignore from tornado.ioloop import IOLoop -from traitlets import Unicode +from traitlets import Unicode # type: ignore from . import __version__ from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME @@ -39,7 +39,7 @@ def initialize(self, argv=None): self.loop = IOLoop.current() self.loop.add_callback(self._record_started) - def setup_signals(self): + def setup_signals(self) -> None: """Shutdown on SIGTERM or SIGINT (Ctrl-C)""" if os.name == 'nt': return @@ -49,17 +49,20 @@ def shutdown_handler(signo, frame): for sig in [signal.SIGTERM, signal.SIGINT]: signal.signal(sig, shutdown_handler) - def shutdown(self, signo): + def shutdown( + self, + signo: int + ) -> None: self.log.info('Shutting down on signal %d' % signo) self.km.shutdown_kernel() self.loop.stop() - def log_connection_info(self): + def log_connection_info(self) -> None: cf = self.km.connection_file self.log.info('Connection file: %s', cf) self.log.info("To connect a client: --existing %s", os.path.basename(cf)) - def _record_started(self): + def _record_started(self) -> None: """For tests, create a file to indicate that we've started Do not rely on this except in our own tests! 
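setup_signals above wires SIGTERM and SIGINT to a handler that shuts the kernel down and stops the tornado loop, skipping the wiring entirely on Windows (os.name == 'nt'). Below is a stripped-down sketch of that pattern that runs standalone: the km.shutdown_kernel() call is replaced by a print, and routing through IOLoop.add_callback_from_signal is an assumption of this sketch (the handler body is not shown in the hunk), chosen because only trivial work is safe inside a signal handler:

import os
import signal

from tornado.ioloop import IOLoop

loop = IOLoop.current()

def shutdown(signo: int) -> None:
    # The real app logs, calls self.km.shutdown_kernel(), then stops the loop.
    print('Shutting down on signal %d' % signo)
    loop.stop()

def shutdown_handler(signo, frame):
    # Re-enter the event loop in a signal-safe way and do the work there.
    loop.add_callback_from_signal(shutdown, signo)

if os.name != 'nt':  # mirrors the early return in setup_signals on Windows
    for sig in [signal.SIGTERM, signal.SIGINT]:
        signal.signal(sig, shutdown_handler)

loop.start()  # runs until a signal (or Ctrl-C) triggers shutdown()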
@@ -69,7 +72,7 @@ def _record_started(self): with open(fn, 'wb'): pass - def start(self): + def start(self) -> None: self.log.info('Starting kernel %r', self.kernel_name) try: self.km.start_kernel() diff --git a/jupyter_client/launcher.py b/jupyter_client/launcher.py index 0646a434a..930ee74b0 100644 --- a/jupyter_client/launcher.py +++ b/jupyter_client/launcher.py @@ -6,12 +6,21 @@ import os import sys from subprocess import Popen, PIPE +from typing import List, Dict, Optional -from traitlets.log import get_logger +from traitlets.log import get_logger # type: ignore -def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, - independent=False, cwd=None, **kw): +def launch_kernel( + cmd: List[str], + stdin: Optional[int] = None, + stdout: Optional[int] = None, + stderr: Optional[int] = None, + env: Optional[Dict[str, str]] = None, + independent: bool = False, + cwd: Optional[str] = None, + **kw +) -> Popen: """ Launches a localhost kernel, binding to the specified ports. Parameters @@ -90,11 +99,11 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"] try: - from _winapi import DuplicateHandle, GetCurrentProcess, \ - DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP + from _winapi import (DuplicateHandle, GetCurrentProcess, + DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP) except: - from _subprocess import DuplicateHandle, GetCurrentProcess, \ - DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP + from _subprocess import (DuplicateHandle, GetCurrentProcess, # type: ignore + DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP) # type: ignore # create a handle on the parent to be inherited if independent: @@ -127,8 +136,7 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, try: # Allow to use ~/ in the command or its arguments - cmd = list(map(os.path.expanduser, cmd)) - + cmd = [os.path.expanduser(s) for s in cmd] proc = Popen(cmd, **kwargs) except Exception as exc: msg = ( @@ -145,11 +153,12 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, if sys.platform == 'win32': # Attach the interrupt event to the Popen objet so it can be used later. - proc.win32_interrupt_event = interrupt_event + proc.win32_interrupt_event = interrupt_event # type: ignore # Clean up pipes created to work around Popen bug. 
if redirect_in: if stdin is None: + assert proc.stdin is not None proc.stdin.close() return proc diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index cf9fb5db1..a2c2ad60c 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -11,25 +11,29 @@ import sys import time import warnings +from subprocess import Popen +import typing as t from enum import Enum import zmq from .localinterfaces import is_local_ip, local_ips -from traitlets import ( +from traitlets import ( # type: ignore Any, Float, Instance, Unicode, List, Bool, Type, DottedObjectName, default, observe, observe_compat ) -from traitlets.utils.importstring import import_item +from traitlets.utils.importstring import import_item # type: ignore from jupyter_client import ( launch_kernel, kernelspec, + KernelClient, ) from .connect import ConnectionFileMixin from .managerabc import ( KernelManagerABC ) +from .utils import run_sync, ensure_async class _ShutdownStatus(Enum): """ @@ -55,39 +59,50 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._shutdown_status = _ShutdownStatus.Unset - _created_context = Bool(False) + _created_context: Bool = Bool(False) # The PyZMQ Context to use for communication with the kernel. - context = Instance(zmq.Context) - def _context_default(self): + context: Instance = Instance(zmq.Context) + + @default('context') + def _context_default(self) -> zmq.Context: self._created_context = True return zmq.Context() # the class to create with our `client` method - client_class = DottedObjectName('jupyter_client.blocking.BlockingKernelClient') - client_factory = Type(klass='jupyter_client.KernelClient') - def _client_factory_default(self): + client_class: DottedObjectName = DottedObjectName('jupyter_client.blocking.BlockingKernelClient') + client_factory: Type = Type(klass='jupyter_client.KernelClient') + + @default('client_factory') + def _client_factory_default(self) -> Type: return import_item(self.client_class) @observe('client_class') - def _client_class_changed(self, change): + def _client_class_changed( + self, + change: t.Dict[str, DottedObjectName] + ) -> None: self.client_factory = import_item(str(change['new'])) # The kernel process with which the KernelManager is communicating. # generally a Popen instance - kernel = Any() + kernel: Any = Any() - kernel_spec_manager = Instance(kernelspec.KernelSpecManager) + kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager) - def _kernel_spec_manager_default(self): + @default('kernel_spec_manager') + def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager: return kernelspec.KernelSpecManager(data_dir=self.data_dir) @observe('kernel_spec_manager') @observe_compat - def _kernel_spec_manager_changed(self, change): + def _kernel_spec_manager_changed( + self, + change: t.Dict[str, Instance] + ) -> None: self._kernel_spec = None - shutdown_wait_time = Float( + shutdown_wait_time: Float = Float( 5.0, config=True, help="Time to wait for a kernel to terminate before killing it, " "in seconds. 
When a shutdown request is initiated, the kernel " @@ -98,23 +113,26 @@ def _kernel_spec_manager_changed(self, change): "and kill may be equivalent on windows.", ) - kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME) + kernel_name: Unicode = Unicode(kernelspec.NATIVE_KERNEL_NAME) @observe('kernel_name') - def _kernel_name_changed(self, change): + def _kernel_name_changed( + self, + change: t.Dict[str, Unicode] + ) -> None: self._kernel_spec = None if change['new'] == 'python': self.kernel_name = kernelspec.NATIVE_KERNEL_NAME - _kernel_spec = None + _kernel_spec: t.Optional[kernelspec.KernelSpec] = None @property - def kernel_spec(self): + def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]: if self._kernel_spec is None and self.kernel_name != '': self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name) return self._kernel_spec - kernel_cmd = List(Unicode(), config=True, + kernel_cmd: List = List(Unicode(), config=True, help="""DEPRECATED: Use kernel_name instead. The Popen Command to launch the kernel. @@ -132,29 +150,29 @@ def _kernel_cmd_changed(self, name, old, new): warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to " "start different kernels.") - cache_ports = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance') + cache_ports: Bool = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance') @default('cache_ports') - def _default_cache_ports(self): + def _default_cache_ports(self) -> bool: return self.transport == 'tcp' @property - def ipykernel(self): + def ipykernel(self) -> bool: return self.kernel_name in {'python', 'python2', 'python3'} # Protected traits - _launch_args = Any() - _control_socket = Any() + _launch_args: Any = Any() + _control_socket: Any = Any() - _restarter = Any() + _restarter: Any = Any() - autorestart = Bool(True, config=True, + autorestart: Bool = Bool(True, config=True, help="""Should we autorestart the kernel if it dies.""" ) - shutting_down = False + shutting_down: bool = False - def __del__(self): + def __del__(self) -> None: self._close_control_socket() self.cleanup_connection_file() @@ -162,19 +180,27 @@ def __del__(self): # Kernel restarter #-------------------------------------------------------------------------- - def start_restarter(self): + def start_restarter(self) -> None: pass - def stop_restarter(self): + def stop_restarter(self) -> None: pass - def add_restart_callback(self, callback, event='restart'): + def add_restart_callback( + self, + callback: t.Callable, + event: str = 'restart' + ) -> None: """register a callback to be called when a kernel is restarted""" if self._restarter is None: return self._restarter.add_callback(callback, event) - def remove_restart_callback(self, callback, event='restart'): + def remove_restart_callback( + self, + callback: t.Callable, + event: str ='restart' + ) -> None: """unregister a callback to be called when a kernel is restarted""" if self._restarter is None: return @@ -184,7 +210,7 @@ def remove_restart_callback(self, callback, event='restart'): # create a Client connected to our Kernel #-------------------------------------------------------------------------- - def client(self, **kwargs): + def client(self, **kwargs) -> KernelClient: """Create a client configured to connect to our kernel""" kw = {} kw.update(self.get_connection_info(session=True)) @@ -201,12 +227,16 @@ def client(self, **kwargs): # Kernel management 
#-------------------------------------------------------------------------- - def format_kernel_cmd(self, extra_arguments=None): + def format_kernel_cmd( + self, + extra_arguments: t.Optional[t.List[str]] = None + ) -> t.List[str]: """replace templated args (e.g. {connection_file})""" extra_arguments = extra_arguments or [] if self.kernel_cmd: cmd = self.kernel_cmd + extra_arguments else: + assert self.kernel_spec is not None cmd = self.kernel_spec.argv + extra_arguments if cmd and cmd[0] in {'python', @@ -239,29 +269,35 @@ def from_ns(match): """Get the key out of ns if it's there, otherwise no change.""" return ns.get(match.group(1), match.group()) - return [ pat.sub(from_ns, arg) for arg in cmd ] + return [pat.sub(from_ns, arg) for arg in cmd] - def _launch_kernel(self, kernel_cmd, **kw): + async def _async_launch_kernel( + self, + kernel_cmd: t.List[str], + **kw + ) -> Popen: """actually launch the kernel override in a subclass to launch kernel subprocesses differently """ return launch_kernel(kernel_cmd, **kw) + _launch_kernel = run_sync(_async_launch_kernel) + # Control socket used for polite kernel shutdown - def _connect_control_socket(self): + def _connect_control_socket(self) -> None: if self._control_socket is None: self._control_socket = self._create_connected_socket('control') self._control_socket.linger = 100 - def _close_control_socket(self): + def _close_control_socket(self) -> None: if self._control_socket is None: return self._control_socket.close() self._control_socket = None - def pre_start_kernel(self, **kw): + def pre_start_kernel(self, **kw) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]: """Prepares a kernel for startup in a separate process. If random ports (port=0) are being used, this method must be called @@ -297,12 +333,17 @@ def pre_start_kernel(self, **kw): if not self.kernel_cmd: # If kernel_cmd has been set manually, don't refer to a kernel spec. # Environment variables from kernel spec are added to os.environ. + assert self.kernel_spec is not None env.update(self._get_env_substitutions(self.kernel_spec.env, env)) kw['env'] = env return kernel_cmd, kw - def _get_env_substitutions(self, templated_env, substitution_values): + def _get_env_substitutions( + self, + templated_env: t.Optional[t.Dict[str, str]], + substitution_values: t.Dict[str, str] + ) -> t.Optional[t.Dict[str, str]]: """ Walks env entries in templated_env and applies possible substitutions from current env (represented by substitution_values). Returns the substituted list of env entries. @@ -318,11 +359,11 @@ def _get_env_substitutions(self, templated_env, substitution_values): substituted_env.update({k: Template(v).safe_substitute(substitution_values)}) return substituted_env - def post_start_kernel(self, **kw): + def post_start_kernel(self, **kw) -> None: self.start_restarter() self._connect_control_socket() - def start_kernel(self, **kw): + async def _async_start_kernel(self, **kw): """Starts a kernel on this host in a separate process. 
If random ports (port=0) are being used, this method must be called @@ -338,10 +379,15 @@ def start_kernel(self, **kw): # launch the kernel subprocess self.log.debug("Starting kernel: %s", kernel_cmd) - self.kernel = self._launch_kernel(kernel_cmd, **kw) + self.kernel = await ensure_async(self._launch_kernel(kernel_cmd, **kw)) self.post_start_kernel(**kw) - def request_shutdown(self, restart=False): + start_kernel = run_sync(_async_start_kernel) + + def request_shutdown( + self, + restart: bool = False + ) -> None: """Send a shutdown request via control channel """ content = dict(restart=restart) @@ -350,7 +396,11 @@ def request_shutdown(self, restart=False): self._connect_control_socket() self.session.send(self._control_socket, msg) - def finish_shutdown(self, waittime=None, pollinterval=0.1): + async def _async_finish_shutdown( + self, + waittime: t.Optional[float] = None, + pollinterval: float = 0.1 + ) -> None: """Wait for kernel shutdown, then kill process if it doesn't shutdown. This does not send shutdown requests - use :meth:`request_shutdown` @@ -359,46 +409,36 @@ def finish_shutdown(self, waittime=None, pollinterval=0.1): if waittime is None: waittime = max(self.shutdown_wait_time, 0) self._shutdown_status = _ShutdownStatus.ShutdownRequest - - def poll_or_sleep_to_kernel_gone(): - """ - Poll until the kernel is not responding, - then wait (the subprocess), until process gone. - - After this function the kernel is either: - - still responding; or - - subprocess has been culled. - """ - if self.is_alive(): - time.sleep(pollinterval) - else: - # If there's still a proc, wait and clear - if self.has_kernel: - self.kernel.wait() - self.kernel = None - return True - - # wait 50% of the shutdown timeout... - for i in range(int(waittime / 2 / pollinterval)): - if poll_or_sleep_to_kernel_gone(): - break - else: - # if we've exited the loop normally (no break) - # send sigterm and wait the other 50%. + try: + await asyncio.wait_for( + self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 + ) + except asyncio.TimeoutError: self.log.debug("Kernel is taking too long to finish, terminating") self._shutdown_status = _ShutdownStatus.SigtermRequest - self._send_kernel_sigterm() - for i in range(int(waittime / 2 / pollinterval)): - if poll_or_sleep_to_kernel_gone(): - break - else: - # OK, we've waited long enough. 
- if self.has_kernel: - self.log.debug("Kernel is taking too long to finish, killing") - self._shutdown_status = _ShutdownStatus.SigkillRequest - self._kill_kernel() + await self._async_send_kernel_sigterm() - def cleanup_resources(self, restart=False): + try: + await asyncio.wait_for( + self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 + ) + except asyncio.TimeoutError: + self.log.debug("Kernel is taking too long to finish, killing") + self._shutdown_status = _ShutdownStatus.SigkillRequest + await ensure_async(self._kill_kernel()) + else: + # Process is no longer alive, wait and clear + if self.kernel is not None: + while self.kernel.poll() is None: + await asyncio.sleep(pollinterval) + self.kernel = None + + finish_shutdown = run_sync(_async_finish_shutdown) + + def cleanup_resources( + self, + restart: bool = False + ) -> None: """Clean up resources when the kernel is shut down""" if not restart: self.cleanup_connection_file() @@ -410,13 +450,20 @@ def cleanup_resources(self, restart=False): if self._created_context and not restart: self.context.destroy(linger=100) - def cleanup(self, connection_file=True): + def cleanup( + self, + connection_file: bool = True + ) -> None: """Clean up resources when the kernel is shut down""" warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).", FutureWarning) self.cleanup_resources(restart=not connection_file) - def shutdown_kernel(self, now=False, restart=False): + async def _async_shutdown_kernel( + self, + now: bool = False, + restart: bool = False + ): """Attempts to stop the kernel process cleanly. This attempts to shutdown the kernels cleanly by: @@ -438,16 +485,16 @@ def shutdown_kernel(self, now=False, restart=False): # Stop monitoring for restarting while we shutdown. self.stop_restarter() - self.interrupt_kernel() + await ensure_async(self.interrupt_kernel()) if now: - self._kill_kernel() + await ensure_async(self._kill_kernel()) else: self.request_shutdown(restart=restart) # Don't send any additional kernel kill messages immediately, to give # the kernel a chance to properly execute shutdown actions. Wait for at # most 1s, checking every 0.1s. - self.finish_shutdown() + await ensure_async(self.finish_shutdown()) # In 6.1.5, a new method, cleanup_resources(), was introduced to address # a leak issue (https://github.com/jupyter/jupyter_client/pull/548) and @@ -470,7 +517,14 @@ def shutdown_kernel(self, now=False, restart=False): else: self.cleanup_resources(restart=restart) - def restart_kernel(self, now=False, newports=False, **kw): + shutdown_kernel = run_sync(_async_shutdown_kernel) + + async def _async_restart_kernel( + self, + now: bool = False, + newports: bool = False, + **kw + ) -> None: """Restarts a kernel with the arguments that were used to launch it. Parameters @@ -500,21 +554,23 @@ def restart_kernel(self, now=False, newports=False, **kw): "No previous call to 'start_kernel'.") else: # Stop currently running kernel. - self.shutdown_kernel(now=now, restart=True) + await ensure_async(self.shutdown_kernel(now=now, restart=True)) if newports: self.cleanup_random_ports() # Start new kernel. 
self._launch_args.update(kw) - self.start_kernel(**self._launch_args) + await ensure_async(self.start_kernel(**self._launch_args)) + + restart_kernel = run_sync(_async_restart_kernel) @property - def has_kernel(self): + def has_kernel(self) -> bool: """Has a kernel been started that we are managing.""" return self.kernel is not None - def _send_kernel_sigterm(self): + async def _async_send_kernel_sigterm(self) -> None: """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" if self.has_kernel: # Signal the kernel to terminate (sends SIGTERM on Unix and @@ -524,7 +580,7 @@ def _send_kernel_sigterm(self): if hasattr(self.kernel, "terminate"): self.kernel.terminate() elif hasattr(signal, "SIGTERM"): - self.signal_kernel(signal.SIGTERM) + await self._async_signal_kernel(signal.SIGTERM) else: self.log.debug( "Cannot set term signal to kernel, no" @@ -534,7 +590,7 @@ def _send_kernel_sigterm(self): # In Windows, we will get an Access Denied error if the process # has already terminated. Ignore it. if sys.platform == "win32": - if e.winerror != 5: + if e.winerror != 5: # type: ignore raise # On Unix, we may get an ESRCH error if the process has already # terminated. Ignore it. @@ -544,268 +600,9 @@ def _send_kernel_sigterm(self): if e.errno != ESRCH: raise - def _kill_kernel(self): - """Kill the running kernel. - - This is a private method, callers should use shutdown_kernel(now=True). - """ - if self.has_kernel: - # Signal the kernel to terminate (sends SIGKILL on Unix and calls - # TerminateProcess() on Win32). - try: - if hasattr(signal, 'SIGKILL'): - self.signal_kernel(signal.SIGKILL) - else: - self.kernel.kill() - except OSError as e: - # In Windows, we will get an Access Denied error if the process - # has already terminated. Ignore it. - if sys.platform == 'win32': - if e.winerror != 5: - raise - # On Unix, we may get an ESRCH error if the process has already - # terminated. Ignore it. - else: - from errno import ESRCH - if e.errno != ESRCH: - raise - - # Block until the kernel terminates. - self.kernel.wait() - self.kernel = None - - def interrupt_kernel(self): - """Interrupts the kernel by sending it a signal. - - Unlike ``signal_kernel``, this operation is well supported on all - platforms. - """ - if self.has_kernel: - interrupt_mode = self.kernel_spec.interrupt_mode - if interrupt_mode == 'signal': - if sys.platform == 'win32': - from .win_interrupt import send_interrupt - send_interrupt(self.kernel.win32_interrupt_event) - else: - self.signal_kernel(signal.SIGINT) - - elif interrupt_mode == 'message': - msg = self.session.msg("interrupt_request", content={}) - self._connect_control_socket() - self.session.send(self._control_socket, msg) - else: - raise RuntimeError("Cannot interrupt kernel. No kernel is running!") + _send_kernel_sigterm = run_sync(_async_send_kernel_sigterm) - def signal_kernel(self, signum): - """Sends a signal to the process group of the kernel (this - usually includes the kernel and any subprocesses spawned by - the kernel). - - Note that since only SIGTERM is supported on Windows, this function is - only useful on Unix systems. - """ - if self.has_kernel: - if hasattr(os, "getpgid") and hasattr(os, "killpg"): - try: - pgid = os.getpgid(self.kernel.pid) - os.killpg(pgid, signum) - return - except OSError: - pass - self.kernel.send_signal(signum) - else: - raise RuntimeError("Cannot signal kernel. 
No kernel is running!") - - def is_alive(self): - """Is the kernel process still running?""" - if self.has_kernel: - if self.kernel.poll() is None: - return True - else: - return False - else: - # we don't have a kernel - return False - - -class AsyncKernelManager(KernelManager): - """Manages kernels in an asynchronous manner """ - - client_class = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') - client_factory = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') - - async def _launch_kernel(self, kernel_cmd, **kw): - """actually launch the kernel - - override in a subclass to launch kernel subprocesses differently - """ - res = launch_kernel(kernel_cmd, **kw) - return res - - async def start_kernel(self, **kw): - """Starts a kernel in a separate process in an asynchronous manner. - - If random ports (port=0) are being used, this method must be called - before the channels are created. - - Parameters - ---------- - `**kw` : optional - keyword arguments that are passed down to build the kernel_cmd - and launching the kernel (e.g. Popen kwargs). - """ - kernel_cmd, kw = self.pre_start_kernel(**kw) - - # launch the kernel subprocess - self.log.debug("Starting kernel (async): %s", kernel_cmd) - self.kernel = await self._launch_kernel(kernel_cmd, **kw) - self.post_start_kernel(**kw) - - async def finish_shutdown(self, waittime=None, pollinterval=0.1): - """Wait for kernel shutdown, then kill process if it doesn't shutdown. - - This does not send shutdown requests - use :meth:`request_shutdown` - first. - """ - if waittime is None: - waittime = max(self.shutdown_wait_time, 0) - self._shutdown_status = _ShutdownStatus.ShutdownRequest - try: - await asyncio.wait_for( - self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 - ) - except asyncio.TimeoutError: - self.log.debug("Kernel is taking too long to finish, terminating") - self._shutdown_status = _ShutdownStatus.SigtermRequest - await self._send_kernel_sigterm() - - try: - await asyncio.wait_for( - self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 - ) - except asyncio.TimeoutError: - self.log.debug("Kernel is taking too long to finish, killing") - self._shutdown_status = _ShutdownStatus.SigkillRequest - await self._kill_kernel() - else: - # Process is no longer alive, wait and clear - if self.kernel is not None: - self.kernel.wait() - self.kernel = None - - async def shutdown_kernel(self, now=False, restart=False): - """Attempts to stop the kernel process cleanly. - - This attempts to shutdown the kernels cleanly by: - - 1. Sending it a shutdown message over the shell channel. - 2. If that fails, the kernel is shutdown forcibly by sending it - a signal. - - Parameters - ---------- - now : bool - Should the kernel be forcible killed *now*. This skips the - first, nice shutdown attempt. - restart: bool - Will this kernel be restarted after it is shutdown. When this - is True, connection files will not be cleaned up. - """ - self.shutting_down = True # Used by restarter to prevent race condition - # Stop monitoring for restarting while we shutdown. - self.stop_restarter() - - await self.interrupt_kernel() - - if now: - await self._kill_kernel() - else: - self.request_shutdown(restart=restart) - # Don't send any additional kernel kill messages immediately, to give - # the kernel a chance to properly execute shutdown actions. Wait for at - # most 1s, checking every 0.1s. - await self.finish_shutdown() - - # See comment in KernelManager.shutdown_kernel(). 
- overrides_cleanup = type(self).cleanup is not AsyncKernelManager.cleanup - overrides_cleanup_resources = type(self).cleanup_resources is not AsyncKernelManager.cleanup_resources - - if overrides_cleanup and not overrides_cleanup_resources: - self.cleanup(connection_file=not restart) - else: - self.cleanup_resources(restart=restart) - - async def restart_kernel(self, now=False, newports=False, **kw): - """Restarts a kernel with the arguments that were used to launch it. - - Parameters - ---------- - now : bool, optional - If True, the kernel is forcefully restarted *immediately*, without - having a chance to do any cleanup action. Otherwise the kernel is - given 1s to clean up before a forceful restart is issued. - - In all cases the kernel is restarted, the only difference is whether - it is given a chance to perform a clean shutdown or not. - - newports : bool, optional - If the old kernel was launched with random ports, this flag decides - whether the same ports and connection file will be used again. - If False, the same ports and connection file are used. This is - the default. If True, new random port numbers are chosen and a - new connection file is written. It is still possible that the newly - chosen random port numbers happen to be the same as the old ones. - - `**kw` : optional - Any options specified here will overwrite those used to launch the - kernel. - """ - if self._launch_args is None: - raise RuntimeError("Cannot restart the kernel. " - "No previous call to 'start_kernel'.") - else: - # Stop currently running kernel. - await self.shutdown_kernel(now=now, restart=True) - - if newports: - self.cleanup_random_ports() - - # Start new kernel. - self._launch_args.update(kw) - await self.start_kernel(**self._launch_args) - return None - - async def _send_kernel_sigterm(self): - """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" - if self.has_kernel: - # Signal the kernel to terminate (sends SIGTERM on Unix and - # if the kernel is a subprocess and we are on windows; this is - # equivalent to kill - try: - if hasattr(self.kernel, "terminate"): - self.kernel.terminate() - elif hasattr(signal, "SIGTERM"): - await self.signal_kernel(signal.SIGTERM) - else: - self.log.debug( - "Cannot set term signal to kernel, no" - " `.terminate()` method and no values for SIGTERM" - ) - except OSError as e: - # In Windows, we will get an Access Denied error if the process - # has already terminated. Ignore it. - if sys.platform == "win32": - if e.winerror != 5: - raise - # On Unix, we may get an ESRCH error if the process has already - # terminated. Ignore it. - else: - from errno import ESRCH - - if e.errno != ESRCH: - raise - - async def _kill_kernel(self): + async def _async_kill_kernel(self) -> None: """Kill the running kernel. This is a private method, callers should use shutdown_kernel(now=True). @@ -815,14 +612,14 @@ async def _kill_kernel(self): # TerminateProcess() on Win32). try: if hasattr(signal, 'SIGKILL'): - await self.signal_kernel(signal.SIGKILL) + await self._async_signal_kernel(signal.SIGKILL) # type: ignore else: self.kernel.kill() except OSError as e: # In Windows, we will get an Access Denied error if the process # has already terminated. Ignore it. if sys.platform == 'win32': - if e.winerror != 5: + if e.winerror != 5: # type: ignore raise # On Unix, we may get an ESRCH error if the process has already # terminated. Ignore it. 
@@ -841,23 +638,27 @@ async def _kill_kernel(self): else: # Process is no longer alive, wait and clear if self.kernel is not None: - self.kernel.wait() + while self.kernel.poll() is None: + await asyncio.sleep(0.1) self.kernel = None - async def interrupt_kernel(self): + _kill_kernel = run_sync(_async_kill_kernel) + + async def _async_interrupt_kernel(self) -> None: """Interrupts the kernel by sending it a signal. Unlike ``signal_kernel``, this operation is well supported on all platforms. """ if self.has_kernel: + assert self.kernel_spec is not None interrupt_mode = self.kernel_spec.interrupt_mode if interrupt_mode == 'signal': if sys.platform == 'win32': from .win_interrupt import send_interrupt send_interrupt(self.kernel.win32_interrupt_event) else: - await self.signal_kernel(signal.SIGINT) + await self._async_signal_kernel(signal.SIGINT) elif interrupt_mode == 'message': msg = self.session.msg("interrupt_request", content={}) @@ -866,7 +667,12 @@ async def interrupt_kernel(self): else: raise RuntimeError("Cannot interrupt kernel. No kernel is running!") - async def signal_kernel(self, signum): + interrupt_kernel = run_sync(_async_interrupt_kernel) + + async def _async_signal_kernel( + self, + signum: int + ) -> None: """Sends a signal to the process group of the kernel (this usually includes the kernel and any subprocesses spawned by the kernel). @@ -877,8 +683,8 @@ async def signal_kernel(self, signum): if self.has_kernel: if hasattr(os, "getpgid") and hasattr(os, "killpg"): try: - pgid = os.getpgid(self.kernel.pid) - os.killpg(pgid, signum) + pgid = os.getpgid(self.kernel.pid) # type: ignore + os.killpg(pgid, signum) # type: ignore return except OSError: pass @@ -886,7 +692,9 @@ async def signal_kernel(self, signum): else: raise RuntimeError("Cannot signal kernel. No kernel is running!") - async def is_alive(self): + signal_kernel = run_sync(_async_signal_kernel) + + async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" if self.has_kernel: if self.kernel.poll() is None: @@ -897,19 +705,45 @@ async def is_alive(self): # we don't have a kernel return False - async def _async_wait(self, pollinterval=0.1): + is_alive = run_sync(_async_is_alive) + + async def _async_wait( + self, + pollinterval: float = 0.1 + ) -> None: # Use busy loop at 100ms intervals, polling until the process is # not alive. If we find the process is no longer alive, complete # its cleanup via the blocking wait(). Callers are responsible for # issuing calls to wait() using a timeout (see _kill_kernel()). 
- while await self.is_alive(): + while await self._async_is_alive(): await asyncio.sleep(pollinterval) +class AsyncKernelManager(KernelManager): + # the class to create with our `client` method + client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') + client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') + + _launch_kernel = KernelManager._async_launch_kernel + start_kernel = KernelManager._async_start_kernel + finish_shutdown = KernelManager._async_finish_shutdown + shutdown_kernel = KernelManager._async_shutdown_kernel + restart_kernel = KernelManager._async_restart_kernel + _send_kernel_sigterm = KernelManager._async_send_kernel_sigterm + _kill_kernel = KernelManager._async_kill_kernel + interrupt_kernel = KernelManager._async_interrupt_kernel + signal_kernel = KernelManager._async_signal_kernel + is_alive = KernelManager._async_is_alive + + KernelManagerABC.register(KernelManager) -def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs): +def start_new_kernel( + startup_timeout: float =60, + kernel_name: str = 'python', + **kwargs +) -> t.Tuple[KernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" km = KernelManager(kernel_name=kernel_name) km.start_kernel(**kwargs) @@ -925,7 +759,11 @@ def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs): return km, kc -async def start_new_async_kernel(startup_timeout=60, kernel_name='python', **kwargs): +async def start_new_async_kernel( + startup_timeout: float = 60, + kernel_name: str = 'python', + **kwargs +) -> t.Tuple[AsyncKernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" km = AsyncKernelManager(kernel_name=kernel_name) await km.start_kernel(**kwargs) @@ -942,7 +780,7 @@ async def start_new_async_kernel(startup_timeout=60, kernel_name='python', **kwa @contextmanager -def run_kernel(**kwargs): +def run_kernel(**kwargs) -> t.Iterator[KernelClient]: """Context manager to create a kernel in a subprocess. The kernel is shut down when the context exits. 
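The manager.py rewrite above hinges on one pattern: each operation is implemented once as a coroutine (_async_*), KernelManager exposes it synchronously via name = run_sync(_async_name), and AsyncKernelManager simply re-binds the coroutine. jupyter_client/utils.py is not part of this diff, so the helpers below are a minimal sketch of plausible run_sync/ensure_async semantics, not the library's actual implementation:

import asyncio
import inspect
import typing as t

def run_sync(coro_func: t.Callable[..., t.Awaitable]) -> t.Callable:
    """Call a coroutine function synchronously on a private event loop.

    Sketch only: the real helper would also need to cope with an
    already-running loop, which this version does not.
    """
    def wrapped(*args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro_func(*args, **kwargs))
        finally:
            loop.close()
    wrapped.__doc__ = coro_func.__doc__
    return wrapped

async def ensure_async(obj: t.Any) -> t.Any:
    """Await obj if it is awaitable, otherwise return it unchanged."""
    if inspect.isawaitable(obj):
        return await obj
    return obj

class Base:
    async def _async_ping(self) -> str:
        return 'pong'

    ping = run_sync(_async_ping)   # sync facade over the coroutine

class AsyncBase(Base):
    ping = Base._async_ping        # async subclass re-exposes the coroutine

print(Base().ping())                        # -> pong
print(asyncio.run(AsyncBase().ping()))      # -> pong
# ensure_async lets shared call sites accept either variant transparently,
# which is why the diff writes `await ensure_async(self.is_alive())`:
print(asyncio.run(ensure_async('pong')))    # plain value -> pong

This explains why shared code paths such as shutdown always go through ensure_async: the same coroutine body works whether the concrete subclass returns plain values or awaitables.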
diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index 5907d7bfa..8d8cb5d9c 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -7,26 +7,35 @@ import os import uuid import socket +import typing as t import zmq -from traitlets.config.configurable import LoggingConfigurable -from traitlets.utils.importstring import import_item -from traitlets import ( +from traitlets.config.configurable import LoggingConfigurable # type: ignore +from traitlets.utils.importstring import import_item # type: ignore +from traitlets import ( # type: ignore Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe ) from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager -from .manager import KernelManager, AsyncKernelManager +from .manager import KernelManager +from .utils import run_sync, ensure_async class DuplicateKernelError(Exception): pass -def kernel_method(f): +def kernel_method( + f: t.Callable +) -> t.Callable: """decorator for proxying MKM.method(kernel_id) to individual KMs by ID""" - def wrapped(self, kernel_id, *args, **kwargs): + def wrapped( + self, + kernel_id: str, + *args, + **kwargs + ) -> t.Union[t.Callable, t.Awaitable]: # get the kernel km = self.get_kernel(kernel_id) method = getattr(km, f.__name__) @@ -72,10 +81,10 @@ def _kernel_manager_class_changed(self, change): def _kernel_manager_factory_default(self): return self._create_kernel_manager_factory() - def _create_kernel_manager_factory(self): + def _create_kernel_manager_factory(self) -> t.Callable: kernel_manager_ctor = import_item(self.kernel_manager_class) - def create_kernel_manager(*args, **kwargs): + def create_kernel_manager(*args, **kwargs) -> KernelManager: if self.shared_context: if self.context.closed: # recreate context if closed @@ -94,7 +103,10 @@ def create_kernel_manager(*args, **kwargs): return create_kernel_manager - def _find_available_port(self, ip): + def _find_available_port( + self, + ip: str + ) -> int: while True: tmp_sock = socket.socket() tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) @@ -119,8 +131,10 @@ def _find_available_port(self, ip): context = Instance('zmq.Context') + _starting_kernels = Dict() + @default("context") - def _context_default(self): + def _context_default(self) -> zmq.Context: self._created_context = True return zmq.Context() @@ -140,20 +154,24 @@ def __del__(self): _kernels = Dict() - def list_kernel_ids(self): + def list_kernel_ids(self) -> t.List[str]: """Return a list of the kernel ids of the active kernels.""" # Create a copy so we can iterate over kernels in operations # that delete keys. return list(self._kernels.keys()) - def __len__(self): + def __len__(self) -> int: """Return the number of running kernels.""" return len(self.list_kernel_ids()) - def __contains__(self, kernel_id): + def __contains__(self, kernel_id) -> bool: return kernel_id in self._kernels - def pre_start_kernel(self, kernel_name, kwargs): + def pre_start_kernel( + self, + kernel_name: t.Optional[str], + kwargs + ) -> t.Tuple[KernelManager, str, str]: # kwargs should be mutable, passing it as a dict argument. 
kernel_id = kwargs.pop('kernel_id', self.new_kernel_id(**kwargs)) if kernel_id in self: @@ -174,7 +192,20 @@ def pre_start_kernel(self, kernel_name, kwargs): ) return km, kernel_name, kernel_id - def start_kernel(self, kernel_name=None, **kwargs): + async def _add_kernel_when_ready( + self, + kernel_id: str, + km: KernelManager, + kernel_awaitable: t.Awaitable + ) -> None: + await kernel_awaitable + self._kernels[kernel_id] = km + + async def _async_start_kernel( + self, + kernel_name: t.Optional[str] = None, + **kwargs + ) -> str: """Start a new kernel. The caller can pick a kernel_id by passing one in as a keyword arg, @@ -183,11 +214,29 @@ def start_kernel(self, kernel_name=None, **kwargs): The kernel ID for the newly started kernel is returned. """ km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs) - km.start_kernel(**kwargs) - self._kernels[kernel_id] = km + if not isinstance(km, KernelManager): + self.log.warning("Kernel manager class ({km_class}) is not an instance of 'KernelManager'!". + format(km_class=self.kernel_manager_class.__class__)) + fut = asyncio.ensure_future( + self._add_kernel_when_ready( + kernel_id, + km, + ensure_async(km.start_kernel(**kwargs)) + ) + ) + self._starting_kernels[kernel_id] = fut + await fut + del self._starting_kernels[kernel_id] return kernel_id - def shutdown_kernel(self, kernel_id, now=False, restart=False): + start_kernel = run_sync(_async_start_kernel) + + async def _async_shutdown_kernel( + self, + kernel_id: str, + now: t.Optional[bool] = False, + restart: t.Optional[bool] = False + ) -> None: """Shutdown a kernel by its kernel uuid. Parameters @@ -208,32 +257,54 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False): km.hb_port, km.control_port ) - km.shutdown_kernel(now=now, restart=restart) + await ensure_async(km.shutdown_kernel(now, restart)) self.remove_kernel(kernel_id) if km.cache_ports and not restart: for port in ports: self.currently_used_ports.remove(port) + shutdown_kernel = run_sync(_async_shutdown_kernel) + @kernel_method - def request_shutdown(self, kernel_id, restart=False): + def request_shutdown( + self, + kernel_id: str, + restart: t.Optional[bool] = False + ) -> None: """Ask a kernel to shut down by its kernel uuid""" @kernel_method - def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1): + def finish_shutdown( + self, + kernel_id: str, + waittime: t.Optional[float] = None, + pollinterval: t.Optional[float] = 0.1 + ) -> None: """Wait for a kernel to finish shutting down, and kill it if it doesn't """ self.log.info("Kernel shutdown: %s" % kernel_id) @kernel_method - def cleanup(self, kernel_id, connection_file=True): + def cleanup( + self, + kernel_id: str, + connection_file: bool = True + ) -> None: """Clean up a kernel's resources""" @kernel_method - def cleanup_resources(self, kernel_id, restart=False): + def cleanup_resources( + self, + kernel_id: str, + restart: bool = False + ) -> None: """Clean up a kernel's resources""" - def remove_kernel(self, kernel_id): + def remove_kernel( + self, + kernel_id: str + ) -> KernelManager: """remove a kernel from our mapping. 
Mainly so that a kernel can be removed if it is already dead, @@ -243,29 +314,35 @@ def remove_kernel(self, kernel_id): """ return self._kernels.pop(kernel_id) - def shutdown_all(self, now=False): + async def _shutdown_starting_kernel( + self, + kid: str, + now: bool + ) -> None: + if kid in self._starting_kernels: + await self._starting_kernels[kid] + await ensure_async(self.shutdown_kernel(kid, now=now)) + + async def _async_shutdown_all( + self, + now: bool = False + ) -> None: """Shutdown all kernels.""" kids = self.list_kernel_ids() - for kid in kids: - self.request_shutdown(kid) - for kid in kids: - self.finish_shutdown(kid) - - # Determine which cleanup method to call - # See comment in KernelManager.shutdown_kernel(). - km = self.get_kernel(kid) - overrides_cleanup = type(km).cleanup is not KernelManager.cleanup - overrides_cleanup_resources = type(km).cleanup_resources is not KernelManager.cleanup_resources - - if overrides_cleanup and not overrides_cleanup_resources: - km.cleanup(connection_file=True) - else: - km.cleanup_resources(restart=False) + futs = [ensure_async(self.shutdown_kernel(kid, now=now)) for kid in kids] + futs += [ + self._shutdown_starting_kernel(kid, now=now) + for kid in self._starting_kernels.keys() + ] + await asyncio.gather(*futs) - self.remove_kernel(kid) + shutdown_all = run_sync(_async_shutdown_all) @kernel_method - def interrupt_kernel(self, kernel_id): + def interrupt_kernel( + self, + kernel_id: str + ) -> None: """Interrupt (SIGINT) the kernel by its uuid. Parameters @@ -276,7 +353,11 @@ def interrupt_kernel(self, kernel_id): self.log.info("Kernel interrupted: %s" % kernel_id) @kernel_method - def signal_kernel(self, kernel_id, signum): + def signal_kernel( + self, + kernel_id: str, + signum: int + ) -> None: """Sends a signal to the kernel by its uuid. Note that since only SIGTERM is supported on Windows, this function @@ -290,7 +371,11 @@ def signal_kernel(self, kernel_id, signum): self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum)) @kernel_method - def restart_kernel(self, kernel_id, now=False): + def restart_kernel( + self, + kernel_id: str, + now: bool = False + ) -> None: """Restart a kernel by its uuid, keeping the same ports. Parameters @@ -301,7 +386,10 @@ def restart_kernel(self, kernel_id, now=False): self.log.info("Kernel restarted: %s" % kernel_id) @kernel_method - def is_alive(self, kernel_id): + def is_alive( + self, + kernel_id: str + ) -> bool: """Is the kernel alive. This calls KernelManager.is_alive() which calls Popen.poll on the @@ -313,12 +401,18 @@ def is_alive(self, kernel_id): The id of the kernel. """ - def _check_kernel_id(self, kernel_id): + def _check_kernel_id( + self, + kernel_id: str + ) -> None: """check that a kernel id is valid""" if kernel_id not in self: raise KeyError("Kernel with id not found: %s" % kernel_id) - def get_kernel(self, kernel_id): + def get_kernel( + self, + kernel_id: str + ) -> KernelManager: """Get the single KernelManager object for a kernel by its uuid. 
Parameters @@ -330,15 +424,28 @@ def get_kernel(self, kernel_id): return self._kernels[kernel_id] @kernel_method - def add_restart_callback(self, kernel_id, callback, event='restart'): + def add_restart_callback( + self, + kernel_id: str, + callback: t.Callable, + event: str = 'restart' + ) -> None: """add a callback for the KernelRestarter""" @kernel_method - def remove_restart_callback(self, kernel_id, callback, event='restart'): + def remove_restart_callback( + self, + kernel_id: str, + callback: t.Callable, + event: str = 'restart' + ) -> None: """remove a callback for the KernelRestarter""" @kernel_method - def get_connection_info(self, kernel_id): + def get_connection_info( + self, + kernel_id: str + ) -> t.Dict[str, t.Any]: """Return a dictionary of connection data for a kernel. Parameters @@ -356,7 +463,11 @@ def get_connection_info(self, kernel_id): """ @kernel_method - def connect_iopub(self, kernel_id, identity=None): + def connect_iopub( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the iopub channel. Parameters @@ -372,7 +483,11 @@ def connect_iopub(self, kernel_id, identity=None): """ @kernel_method - def connect_shell(self, kernel_id, identity=None): + def connect_shell( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the shell channel. Parameters @@ -388,7 +503,11 @@ def connect_shell(self, kernel_id, identity=None): """ @kernel_method - def connect_control(self, kernel_id, identity=None): + def connect_control( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the control channel. Parameters @@ -404,7 +523,11 @@ def connect_control(self, kernel_id, identity=None): """ @kernel_method - def connect_stdin(self, kernel_id, identity=None): + def connect_stdin( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the stdin channel. Parameters @@ -420,7 +543,11 @@ def connect_stdin(self, kernel_id, identity=None): """ @kernel_method - def connect_hb(self, kernel_id, identity=None): + def connect_hb( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the hb channel. Parameters @@ -435,7 +562,7 @@ def connect_hb(self, kernel_id, identity=None): stream : zmq Socket or ZMQStream """ - def new_kernel_id(self, **kwargs): + def new_kernel_id(self, **kwargs) -> str: """ Returns the id to associate with the kernel for this request. Subclasses may override this method to substitute other sources of kernel ids. @@ -454,121 +581,6 @@ class AsyncMultiKernelManager(MultiKernelManager): """ ) - _starting_kernels = Dict() - - async def _add_kernel_when_ready(self, kernel_id, km, kernel_awaitable): - await kernel_awaitable - self._kernels[kernel_id] = km - - async def start_kernel(self, kernel_name=None, **kwargs): - """Start a new kernel. - - The caller can pick a kernel_id by passing one in as a keyword arg, - otherwise one will be generated using new_kernel_id(). - - The kernel ID for the newly started kernel is returned. - """ - km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs) - if not isinstance(km, AsyncKernelManager): - self.log.warning("Kernel manager class ({km_class}) is not an instance of 'AsyncKernelManager'!". 
- format(km_class=self.kernel_manager_class.__class__)) - fut = asyncio.ensure_future( - self._add_kernel_when_ready( - kernel_id, - km, - km.start_kernel(**kwargs) - ) - ) - self._starting_kernels[kernel_id] = fut - await fut - del self._starting_kernels[kernel_id] - return kernel_id - - async def shutdown_kernel(self, kernel_id, now=False, restart=False): - """Shutdown a kernel by its kernel uuid. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to shutdown. - now : bool - Should the kernel be shutdown forcibly using a signal. - restart : bool - Will the kernel be restarted? - """ - self.log.info("Kernel shutdown: %s" % kernel_id) - - km = self.get_kernel(kernel_id) - - ports = ( - km.shell_port, km.iopub_port, km.stdin_port, - km.hb_port, km.control_port - ) - - await km.shutdown_kernel(now, restart) - self.remove_kernel(kernel_id) - - if km.cache_ports and not restart: - for port in ports: - self.currently_used_ports.remove(port) - - async def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1): - """Wait for a kernel to finish shutting down, and kill it if it doesn't - """ - km = self.get_kernel(kernel_id) - await km.finish_shutdown(waittime, pollinterval) - self.log.info("Kernel shutdown: %s" % kernel_id) - - async def interrupt_kernel(self, kernel_id): - """Interrupt (SIGINT) the kernel by its uuid. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to interrupt. - """ - km = self.get_kernel(kernel_id) - await km.interrupt_kernel() - self.log.info("Kernel interrupted: %s" % kernel_id) - - async def signal_kernel(self, kernel_id, signum): - """Sends a signal to the kernel by its uuid. - - Note that since only SIGTERM is supported on Windows, this function - is only useful on Unix systems. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to signal. - """ - km = self.get_kernel(kernel_id) - await km.signal_kernel(signum) - self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum)) - - async def restart_kernel(self, kernel_id, now=False): - """Restart a kernel by its uuid, keeping the same ports. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to interrupt. 
- """ - km = self.get_kernel(kernel_id) - await km.restart_kernel(now) - self.log.info("Kernel restarted: %s" % kernel_id) - - async def _shutdown_starting_kernel(self, kid, now): - if kid in self._starting_kernels: - await self._starting_kernels[kid] - await self.shutdown_kernel(kid, now=now) - - async def shutdown_all(self, now=False): - """Shutdown all kernels.""" - kids = self.list_kernel_ids() - futs = [self.shutdown_kernel(kid, now=now) for kid in kids] - futs += [ - self._shutdown_starting_kernel(kid, now=now) - for kid in self._starting_kernels.keys() - ] - await asyncio.gather(*futs) + start_kernel = MultiKernelManager._async_start_kernel + shutdown_kernel = MultiKernelManager._async_shutdown_kernel + shutdown_all = MultiKernelManager._async_shutdown_all diff --git a/jupyter_client/session.py b/jupyter_client/session.py index 88fc4746f..fde87d165 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -21,6 +21,7 @@ import pprint import random import warnings +import typing as t from datetime import datetime from datetime import timezone @@ -42,13 +43,13 @@ from jupyter_client import protocol_version from jupyter_client.adapter import adapt -from traitlets import ( +from traitlets import ( # type: ignore CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode, Dict, Integer, TraitError, observe ) -from traitlets.log import get_logger -from traitlets.utils.importstring import import_item -from traitlets.config.configurable import Configurable, LoggingConfigurable +from traitlets.log import get_logger # type: ignore +from traitlets.utils.importstring import import_item # type: ignore +from traitlets.config.configurable import Configurable, LoggingConfigurable # type: ignore #----------------------------------------------------------------------------- # utility functions @@ -98,7 +99,7 @@ def squash_unicode(obj): # Mixin tools for apps that use Sessions #----------------------------------------------------------------------------- -def new_id(): +def new_id() -> str: """Generate a new random id. Avoids problematic runtime import in stdlib uuid on Python 2. @@ -113,7 +114,7 @@ def new_id(): buf[:4], buf[4:] )) -def new_id_bytes(): +def new_id_bytes() -> bytes: """Return new_id as ascii bytes""" return new_id().encode('ascii') @@ -123,7 +124,7 @@ def new_id_bytes(): keyfile = 'Session.keyfile', ) -session_flags = { +session_flags = { 'secure' : ({'Session' : { 'key' : new_id_bytes(), 'keyfile' : '' }}, """Use HMAC digests for authentication of messages. @@ -133,7 +134,7 @@ def new_id_bytes(): """Don't authenticate messages."""), } -def default_secure(cfg): +def default_secure(cfg) -> None: """Set the default behavior for a config environment to be secure. 
If Session.key/keyfile have not been set, set Session.key to @@ -146,7 +147,7 @@ def default_secure(cfg): # key/keyfile not specified, generate new UUID: cfg.Session.key = new_id_bytes() -def utcnow(): +def utcnow() -> datetime: """Return timezone-aware UTC timestamp""" return datetime.utcnow().replace(tzinfo=utc) @@ -162,12 +163,12 @@ class SessionFactory(LoggingConfigurable): logname = Unicode('') @observe('logname') - def _logname_changed(self, change): + def _logname_changed(self, change) -> None: self.log = logging.getLogger(change['new']) # not configurable: context = Instance('zmq.Context') - def _context_default(self): + def _context_default(self) -> zmq.Context: return zmq.Context() session = Instance('jupyter_client.session.Session', @@ -191,7 +192,10 @@ class Message(object): A Message can be created from a dict and a dict from a Message instance simply by calling dict(msg_obj).""" - def __init__(self, msg_dict): + def __init__( + self, + msg_dict: t.Dict[str, t.Any] + ) -> None: dct = self.__dict__ for k, v in dict(msg_dict).items(): if isinstance(v, dict): @@ -199,29 +203,36 @@ def __init__(self, msg_dict): dct[k] = v # Having this iterator lets dict(msg_obj) work out of the box. - def __iter__(self): + def __iter__(self) -> t.ItemsView[str, t.Any]: return self.__dict__.items() - def __repr__(self): + def __repr__(self) -> str: return repr(self.__dict__) - def __str__(self): + def __str__(self) -> str: return pprint.pformat(self.__dict__) - def __contains__(self, k): + def __contains__(self, k) -> bool: return k in self.__dict__ - def __getitem__(self, k): + def __getitem__(self, k) -> t.Any: return self.__dict__[k] -def msg_header(msg_id, msg_type, username, session): +def msg_header( + msg_id: str, + msg_type: str, + username: str, + session: 'Session' +) -> t.Dict[str, t.Any]: """Create a new message header""" date = utcnow() version = protocol_version return locals() -def extract_header(msg_or_header): +def extract_header( + msg_or_header: t.Dict[str, t.Any] +) -> t.Dict[str, t.Any]: """Given a message or header, return the header.""" if not msg_or_header: return {} @@ -328,7 +339,7 @@ def _unpacker_changed(self, change): session = CUnicode('', config=True, help="""The UUID identifying this session.""") - def _session_default(self): + def _session_default(self) -> str: u = new_id() self.bsession = u.encode('ascii') return u @@ -355,7 +366,7 @@ def _session_changed(self, change): key = CBytes(config=True, help="""execution key, for signing messages.""") - def _key_default(self): + def _key_default(self) -> bytes: return new_id_bytes() @observe('key') @@ -380,12 +391,12 @@ def _signature_scheme_changed(self, change): self._new_auth() digest_mod = Any() - def _digest_mod_default(self): + def _digest_mod_default(self) -> t.Callable: return hashlib.sha256 auth = Instance(hmac.HMAC, allow_none=True) - def _new_auth(self): + def _new_auth(self) -> None: if self.key: self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) else: @@ -491,7 +502,7 @@ def __init__(self, **kwargs): if not self.key: get_logger().warning("Message signing is disabled. This is insecure and not recommended!") - def clone(self): + def clone(self) -> 'Session': """Create a copy of this Session Useful when connecting multiple times to a given kernel. 
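Note (illustrative, not part of the patch): the Session methods annotated in the hunks below
round-trip as follows — a dict built by `session.msg()` goes through `serialize` (which signs
it), `feed_identities` strips the routing prefix up to DELIM, and `deserialize` verifies the
HMAC signature and rebuilds the dict. A small sketch:

    from jupyter_client.session import Session, new_id_bytes

    session = Session(key=new_id_bytes())       # signing enabled via HMAC

    msg = session.msg('kernel_info_request', content={})
    wire = session.serialize(msg)               # [DELIM, signature, header, parent, metadata, content]

    idents, frames = session.feed_identities(wire)  # split identities off before DELIM
    restored = session.deserialize(frames)          # checks the signature, unpacks the dicts
    assert restored['header']['msg_type'] == 'kernel_info_request'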
@@ -511,28 +522,28 @@ def clone(self): message_count = 0 @property - def msg_id(self): + def msg_id(self) -> str: message_number = self.message_count self.message_count += 1 return '{}_{}'.format(self.session, message_number) - def _check_packers(self): + def _check_packers(self) -> None: """check packers for datetime support.""" pack = self.pack unpack = self.unpack # check simple serialization - msg = dict(a=[1,'hi']) + msg_list = dict(a=[1,'hi']) try: - packed = pack(msg) + packed = pack(msg_list) except Exception as e: - msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" + error_msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( - msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) + error_msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) ) from e # ensure packed message is bytes @@ -542,31 +553,41 @@ def _check_packers(self): # check that unpack is pack's inverse try: unpacked = unpack(packed) - assert unpacked == msg + assert unpacked == msg_list except Exception as e: - msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" + error_msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( - msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) + error_msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) ) from e # check datetime support - msg = dict(t=utcnow()) + msg_datetime = dict(t=utcnow()) try: - unpacked = unpack(pack(msg)) + unpacked = unpack(pack(msg_datetime)) if isinstance(unpacked['t'], datetime): raise ValueError("Shouldn't deserialize to datetime") except Exception: self.pack = lambda o: pack(squash_dates(o)) self.unpack = lambda s: unpack(s) - def msg_header(self, msg_type): + def msg_header( + self, + msg_type: str + ) -> t.Dict[str, t.Any]: return msg_header(self.msg_id, msg_type, self.username, self.session) - def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): + def msg( + self, + msg_type: str, + content: t.Optional[t.Dict] = None, + parent: t.Optional[t.Dict[str, t.Any]] = None, + header: t.Optional[t.Dict[str, t.Any]] = None, + metadata: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Dict[str, t.Any]: """Return the nested message dict. This format is different from what is sent over the wire. The @@ -585,7 +606,10 @@ def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): msg['metadata'].update(metadata) return msg - def sign(self, msg_list): + def sign( + self, + msg_list: t.List + ) -> bytes: """Sign a message with HMAC digest. If no auth, return b''. Parameters @@ -600,7 +624,11 @@ def sign(self, msg_list): h.update(m) return h.hexdigest().encode() - def serialize(self, msg, ident=None): + def serialize( + self, + msg: t.Dict[str, t.Any], + ident: t.Optional[t.Union[t.List[bytes], bytes]] = None + ) -> t.List[bytes]: """Serialize the message components to bytes. This is roughly the inverse of deserialize. 
        The serialize/deserialize
@@ -659,8 +687,18 @@ def serialize(self, msg, ident=None):

         return to_send

-    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
-             buffers=None, track=False, header=None, metadata=None):
+    def send(
+        self,
+        stream: zmq.sugar.socket.Socket,
+        msg_or_type: t.Union[t.Dict[str, t.Any], str],
+        content: t.Optional[t.Dict[str, t.Any]] = None,
+        parent: t.Optional[t.Dict[str, t.Any]] = None,
+        ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None,
+        buffers: t.Optional[t.List[bytes]] = None,
+        track: bool = False,
+        header: t.Optional[t.Dict[str, t.Any]] = None,
+        metadata: t.Optional[t.Dict[str, t.Any]] = None
+    ) -> t.Optional[t.Dict[str, t.Any]]:
         """Build and send a message via stream or socket.

         The message format used by this function internally is as follows:
@@ -720,7 +758,7 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             get_logger().warning("WARNING: attempted to send message from fork\n%s",
                 msg
             )
-            return
+            return None
         buffers = [] if buffers is None else buffers
         for idx, buf in enumerate(buffers):
             if isinstance(buf, memoryview):
@@ -761,7 +799,14 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None,

         return msg

-    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
+    def send_raw(
+        self,
+        stream: zmq.sugar.socket.Socket,
+        msg_list: t.List,
+        flags: int = 0,
+        copy: bool = True,
+        ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None,
+    ) -> None:
         """Send a raw message via ident path.

         This method is used to send an already serialized message.
@@ -789,7 +834,13 @@ def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
             to_send.extend(msg_list)
         stream.send_multipart(to_send, flags, copy=copy)

-    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
+    def recv(
+        self,
+        socket: zmq.sugar.socket.Socket,
+        mode: int = zmq.NOBLOCK,
+        content: bool = True,
+        copy: bool = True
+    ) -> t.Tuple[t.Optional[t.List[bytes]], t.Optional[t.Dict[str, t.Any]]]:
         """Receive and unpack a message.

         Parameters
@@ -811,7 +862,7 @@ def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
             if e.errno == zmq.EAGAIN:
                 # We can convert EAGAIN to None as we know in this case
                 # recv_multipart won't return None.
-                return None,None
+                return None, None
             else:
                 raise
         # split multipart message into identity list and message dict
@@ -823,7 +874,11 @@ def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
             # TODO: handle it
             raise e

-    def feed_identities(self, msg_list, copy=True):
+    def feed_identities(
+        self,
+        msg_list: t.Union[t.List[bytes], t.List[zmq.Message]],
+        copy: bool = True
+    ) -> t.Tuple[t.List[bytes], t.Union[t.List[bytes], t.List[zmq.Message]]]:
         """Split the identities from the rest of the message.

         Feed until DELIM is reached, then return the prefix as idents and
@@ -847,20 +902,25 @@ def feed_identities(self, msg_list, copy=True):
         point.
""" if copy: + msg_list = t.cast(t.List[bytes], msg_list) idx = msg_list.index(DELIM) return msg_list[:idx], msg_list[idx+1:] else: + msg_list = t.cast(t.List[zmq.Message], msg_list) failed = True - for idx,m in enumerate(msg_list): + for idx, m in enumerate(msg_list): if m.bytes == DELIM: failed = False break if failed: raise ValueError("DELIM not in msg_list") idents, msg_list = msg_list[:idx], msg_list[idx+1:] - return [m.bytes for m in idents], msg_list + return [bytes(m.bytes) for m in idents], msg_list - def _add_digest(self, signature): + def _add_digest( + self, + signature: bytes + ) -> None: """add a digest to history to protect against replay attacks""" if self.digest_history_size == 0: # no history, never add digests @@ -871,7 +931,7 @@ def _add_digest(self, signature): # threshold reached, cull 10% self._cull_digest_history() - def _cull_digest_history(self): + def _cull_digest_history(self) -> None: """cull the digest history Removes a randomly selected 10% of the digest history @@ -884,7 +944,12 @@ def _cull_digest_history(self): to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull) self.digest_history.difference_update(to_cull) - def deserialize(self, msg_list, content=True, copy=True): + def deserialize( + self, + msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], + content: bool =True, + copy: bool =True + ) -> t.Dict[str, t.Any]: """Unserialize a msg_list to a nested message dict. This is roughly the inverse of serialize. The serialize/deserialize @@ -913,10 +978,13 @@ def deserialize(self, msg_list, content=True, copy=True): message = {} if not copy: # pyzmq didn't copy the first parts of the message, so we'll do it - for i in range(minlen): - msg_list[i] = msg_list[i].bytes + msg_list = t.cast(t.List[zmq.Message], msg_list) + msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]] + msg_list = t.cast(t.List[bytes], msg_list) + msg_list = msg_list_beginning + msg_list[minlen:] + msg_list = t.cast(t.List[bytes], msg_list) if self.auth is not None: - signature = msg_list[0] + signature = t.cast(bytes, msg_list[0]) if not signature: raise ValueError("Unsigned Message") if signature in self.digest_history: @@ -942,14 +1010,15 @@ def deserialize(self, msg_list, content=True, copy=True): buffers = [memoryview(b) for b in msg_list[5:]] if buffers and buffers[0].shape is None: # force copy to workaround pyzmq #646 - buffers = [memoryview(b.bytes) for b in msg_list[5:]] + msg_list = t.cast(t.List[zmq.Message], msg_list) + buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]] message['buffers'] = buffers if self.debug: pprint.pprint(message) # adapt to the current version return adapt(message) - def unserialize(self, *args, **kwargs): + def unserialize(self, *args, **kwargs) -> t.Dict[str, t.Any]: warnings.warn( "Session.unserialize is deprecated. 
Use Session.deserialize.", DeprecationWarning, diff --git a/jupyter_client/tests/test_kernelapp.py b/jupyter_client/tests/test_kernelapp.py index 17f793a98..af28f814a 100644 --- a/jupyter_client/tests/test_kernelapp.py +++ b/jupyter_client/tests/test_kernelapp.py @@ -35,13 +35,21 @@ def test_kernelapp_lifecycle(): .format(WAIT_TIME)) # Connection file should be there by now - files = os.listdir(runtime_dir) + for _ in range(WAIT_TIME * POLL_FREQ): + files = os.listdir(runtime_dir) + if files: + break + time.sleep(1 / POLL_FREQ) + else: + raise AssertionError("No connection file created in {} seconds" + .format(WAIT_TIME)) assert len(files) == 1 cf = files[0] assert cf.startswith('kernel') assert cf.endswith('.json') # Send SIGTERM to shut down + time.sleep(1) p.terminate() _, stderr = p.communicate(timeout=WAIT_TIME) assert cf in stderr.decode('utf-8', 'replace') diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index 7f5ea13dc..5380bdca0 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -10,10 +10,10 @@ import signal import sys import time -import threading -import multiprocessing as mp +import concurrent.futures import pytest +import nest_asyncio from async_generator import async_generator, yield_ from traitlets.config.loader import Config from jupyter_core import paths @@ -242,8 +242,8 @@ def execute(cmd): content = reply['content'] assert content['status'] == 'ok' assert content['user_expressions']['interrupted'] - # wait up to 5s for subprocesses to handle signal - for i in range(50): + # wait up to 10s for subprocesses to handle signal + for i in range(100): reply = execute('check') if reply['user_expressions']['poll'] != [-signal.SIGINT] * N: time.sleep(0.1) @@ -350,41 +350,34 @@ def test_start_parallel_thread_kernels(self, config, install_kernel): pytest.skip("IPC transport is currently not working for this test!") self._run_signaltest_lifecycle(config) - thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - thread2 = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - try: - thread.start() - thread2.start() - finally: - thread.join() - thread2.join() + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor: + future1 = thread_executor.submit(self._run_signaltest_lifecycle, config) + future2 = thread_executor.submit(self._run_signaltest_lifecycle, config) + future1.result() + future2.result() @pytest.mark.timeout(TIMEOUT) + @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') def test_start_parallel_process_kernels(self, config, install_kernel): if config.KernelManager.transport == 'ipc': # FIXME pytest.skip("IPC transport is currently not working for this test!") self._run_signaltest_lifecycle(config) - thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,)) - try: - thread.start() - proc.start() - finally: - thread.join() - proc.join() - - assert proc.exitcode == 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor: + future1 = thread_executor.submit(self._run_signaltest_lifecycle, config) + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor: + future2 = process_executor.submit(self._run_signaltest_lifecycle, config) + future2.result() + 
            future1.result()

     @pytest.mark.timeout(TIMEOUT)
+    @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error')
     def test_start_sequence_process_kernels(self, config, install_kernel):
+        if config.KernelManager.transport == 'ipc':  # FIXME
+            pytest.skip("IPC transport is currently not working for this test!")
         self._run_signaltest_lifecycle(config)
-        proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,))
-        try:
-            proc.start()
-        finally:
-            proc.join()
-
-        assert proc.exitcode == 0
+        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool_executor:
+            future = pool_executor.submit(self._run_signaltest_lifecycle, config)
+            future.result()

     def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs):
         km.start_kernel(**kwargs)
diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py
index ff13e8282..fba0eff04 100644
--- a/jupyter_client/tests/test_multikernelmanager.py
+++ b/jupyter_client/tests/test_multikernelmanager.py
@@ -1,15 +1,16 @@
 """Tests for the notebook kernel and session manager."""

 import asyncio
-import threading
+import concurrent.futures
 import uuid
-import multiprocessing as mp
+import sys

+import pytest
 from subprocess import PIPE
 from unittest import TestCase
 from tornado.testing import AsyncTestCase, gen_test
 from traitlets.config.loader import Config
-from jupyter_client import KernelManager
+from jupyter_client import KernelManager, AsyncKernelManager
 from jupyter_client.multikernelmanager import MultiKernelManager, AsyncMultiKernelManager
 from .utils import skip_win32, SyncMKMSubclass, AsyncMKMSubclass, SyncKMSubclass, AsyncKMSubclass
 from ..localinterfaces import localhost
@@ -134,30 +135,23 @@ def tcp_lifecycle_with_loop(self):
     def test_start_parallel_thread_kernels(self):
         self.test_tcp_lifecycle()

-        thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        try:
-            thread.start()
-            thread2.start()
-        finally:
-            thread.join()
-            thread2.join()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor:
+            future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            future2 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            future1.result()
+            future2.result()

+    @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error')
     def test_start_parallel_process_kernels(self):
         self.test_tcp_lifecycle()

-        thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        # Windows tests needs this target to be picklable:
-        proc = mp.Process(target=self.test_tcp_lifecycle)
-
-        try:
-            thread.start()
-            proc.start()
-        finally:
-            thread.join()
-            proc.join()
-
-        assert proc.exitcode == 0
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
+            future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor:
+                # Windows tests need this target to be picklable:
+                future2 = process_executor.submit(self.test_tcp_lifecycle)
+                future2.result()
+            future1.result()

     def test_subclass_callables(self):
         km = self._get_tcp_km_sub()
@@ -206,10 +200,10 @@ def test_subclass_callables(self):
         km.get_kernel(kid).reset_counts()
         km.reset_counts()
         km.shutdown_all(now=True)
-        assert km.call_count('shutdown_kernel') == 0
+        assert km.call_count('shutdown_kernel') == 1
         assert km.call_count('remove_kernel') == 1
-        assert km.call_count('request_shutdown') == 1
-        assert km.call_count('finish_shutdown') == 1
+        assert km.call_count('request_shutdown') == 0
+        assert km.call_count('finish_shutdown') == 0
         assert km.call_count('cleanup_resources') == 0
         assert kid not in km, f'{kid} not in {km}'
@@ -256,7 +250,7 @@ async def _run_lifecycle(km, test_kid=None):
         assert kid in km.list_kernel_ids()
         await km.interrupt_kernel(kid)
         k = km.get_kernel(kid)
-        assert isinstance(k, KernelManager)
+        assert isinstance(k, AsyncKernelManager)
         await km.shutdown_kernel(kid, now=True)
         assert kid not in km, f'{kid} not in {km}'
@@ -389,31 +383,23 @@ def raw_tcp_lifecycle_sync(cls, test_kid=None):
     async def test_start_parallel_thread_kernels(self):
         await self.raw_tcp_lifecycle()

-        thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        try:
-            thread.start()
-            thread2.start()
-        finally:
-            thread.join()
-            thread2.join()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor:
+            future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            future2 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            future1.result()
+            future2.result()

     @gen_test
     async def test_start_parallel_process_kernels(self):
         await self.raw_tcp_lifecycle()

-        thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
-        # Windows tests needs this target to be picklable:
-        proc = mp.Process(target=self.raw_tcp_lifecycle_sync)
-
-        try:
-            thread.start()
-            proc.start()
-        finally:
-            proc.join()
-            thread.join()
-
-        assert proc.exitcode == 0
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
+            future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
+            with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor:
+                # Windows tests need this target to be picklable:
+                future2 = process_executor.submit(self.raw_tcp_lifecycle_sync)
+                future2.result()
+            future1.result()

     @gen_test
     async def test_subclass_callables(self):
diff --git a/jupyter_client/tests/test_public_api.py b/jupyter_client/tests/test_public_api.py
index ab3883d66..5ebf2f3d3 100644
--- a/jupyter_client/tests/test_public_api.py
+++ b/jupyter_client/tests/test_public_api.py
@@ -9,12 +9,12 @@


 def test_kms():
-    for base in ("", "Multi"):
+    for base in ("", "Async", "Multi"):
         KM = base + "KernelManager"
         assert KM in dir(jupyter_client)


 def test_kcs():
-    for base in ("", "Blocking"):
+    for base in ("", "Blocking", "Async"):
         KM = base + "KernelClient"
         assert KM in dir(jupyter_client)
diff --git a/jupyter_client/utils.py b/jupyter_client/utils.py
index 2f49d2103..942932b12 100644
--- a/jupyter_client/utils.py
+++ b/jupyter_client/utils.py
@@ -1,8 +1,37 @@
 """
-Utils vendored from ipython_genutils that should be retired at some point.
+utils:
+- provides utility wrappers to run asynchronous functions in a blocking environment.
+- vendors functions from ipython_genutils that should be retired at some point.
""" import os +import sys +import asyncio +import inspect +import nest_asyncio + + +if os.name == 'nt' and sys.version_info >= (3, 7): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + +def run_sync(coro): + def wrapped(*args, **kwargs): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + nest_asyncio.apply(loop) + return loop.run_until_complete(coro(*args, **kwargs)) + wrapped.__doc__ = coro.__doc__ + return wrapped + + +async def ensure_async(obj): + if inspect.isawaitable(obj): + return await obj + return obj def _filefind(filename, path_dirs=None): diff --git a/setup.py b/setup.py index 1952004b0..8100b3f09 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ def run(self): 'pyzmq>=13', 'python-dateutil>=2.1', 'tornado>=4.1', + 'nest-asyncio>=1.5', ], python_requires = '>=3.5', extras_require = { @@ -86,6 +87,7 @@ def run(self): 'pytest-asyncio', 'pytest-timeout', 'pytest', + 'mypy', ], 'doc': open('docs/requirements.txt').read().splitlines(), },