From ad97a60002e93f3b260f4f69df5e20b2652019f8 Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Wed, 19 Jul 2023 14:05:24 +0200
Subject: [PATCH 1/6] add KernelManager.exit_status

---
 jupyter_client/manager.py   | 11 ++++++++++
 tests/test_kernelmanager.py | 40 +++++++++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py
index f04bd987..8626407d 100644
--- a/jupyter_client/manager.py
+++ b/jupyter_client/manager.py
@@ -655,6 +655,17 @@ async def _async_is_alive(self) -> bool:
 
     is_alive = run_sync(_async_is_alive)
 
+    async def _async_exit_status(self) -> int | None:
+        """Returns 0 if there's no kernel or it exited gracefully,
+        None if the kernel is running, or a negative value `-N` if the
+        kernel was killed by signal `N` (posix only)."""
+        if not self.has_kernel:
+            return 0
+        assert self.provisioner is not None
+        return await self.provisioner.poll()
+
+    exit_status = run_sync(_async_exit_status)
+
     async def _async_wait(self, pollinterval: float = 0.1) -> None:
         # Use busy loop at 100ms intervals, polling until the process is
         # not alive.  If we find the process is no longer alive, complete
diff --git a/tests/test_kernelmanager.py b/tests/test_kernelmanager.py
index f2d749eb..989d5198 100644
--- a/tests/test_kernelmanager.py
+++ b/tests/test_kernelmanager.py
@@ -160,6 +160,46 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected):
         assert km._shutdown_status in expected
 
 
+class TestKernelManagerExitStatus:
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
+    @pytest.mark.parametrize('_signal', [signal.SIGHUP, signal.SIGTERM, signal.SIGKILL])
+    async def test_exit_status(self, _signal):
+        # install kernel
+        _install_kernel(name="test_exit_status")
+
+        # start kernel
+        km, kc = start_new_kernel(kernel_name="test_exit_status")
+
+        # stop restarter - not needed?
+        # km.stop_restarter()
+
+        # check that process is running
+        assert km.exit_status() is None
+
+        # get the provisioner
+        # send signal
+        provisioner = km.provisioner
+        assert provisioner is not None
+        assert provisioner.has_process
+        await provisioner.send_signal(_signal)
+
+        # wait for the process to exit
+        try:
+            await asyncio.wait_for(km._async_wait(), timeout=3.0)
+        except TimeoutError:
+            assert False, f'process never stopped for signal {signal}'
+
+        # check that the signal is correct
+        assert km.exit_status() == -_signal
+
+        # doing a proper shutdown now wipes the status, might be bad?
+        km.shutdown_kernel(now=True)
+        assert km.exit_status() == 0
+
+        # stop channels so cleanup doesn't complain
+        kc.stop_channels()
+
+
 class TestKernelManager:
     def test_lifecycle(self, km):
         km.start_kernel(stdout=PIPE, stderr=PIPE)

From b6b5cf0a16992d281d2f80cdf15a7fe79631386b Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Thu, 20 Jul 2023 12:16:59 +0200
Subject: [PATCH 2/6] fix CI, test signals on windows

---
 jupyter_client/manager.py   | 2 +-
 tests/test_kernelmanager.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py
index 8626407d..7ad2002b 100644
--- a/jupyter_client/manager.py
+++ b/jupyter_client/manager.py
@@ -655,7 +655,7 @@ async def _async_is_alive(self) -> bool:
 
     is_alive = run_sync(_async_is_alive)
 
-    async def _async_exit_status(self) -> int | None:
+    async def _async_exit_status(self) -> t.Optional[int]:
         """Returns 0 if there's no kernel or it exited gracefully,
         None if the kernel is running, or a negative value `-N` if the
         kernel was killed by signal `N` (posix only)."""
diff --git a/tests/test_kernelmanager.py b/tests/test_kernelmanager.py
index 989d5198..65b787a1 100644
--- a/tests/test_kernelmanager.py
+++ b/tests/test_kernelmanager.py
@@ -161,8 +161,7 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected):
 
 
 class TestKernelManagerExitStatus:
-    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
-    @pytest.mark.parametrize('_signal', [signal.SIGHUP, signal.SIGTERM, signal.SIGKILL])
+    @pytest.mark.parametrize('_signal', [signal.SIGILL, signal.SIGSEGV, signal.SIGTERM])
     async def test_exit_status(self, _signal):
         # install kernel
         _install_kernel(name="test_exit_status")

From b6816eaf47291fc0ae40406c80300f5e6f1c7027 Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Thu, 17 Aug 2023 12:18:22 +0200
Subject: [PATCH 3/6] add accepts_exit_code parameter to add_restart_callback
 and make restarter pass the exit code to callback when restarting

---
 jupyter_client/ioloop/restarter.py |  8 +++----
 jupyter_client/manager.py          | 10 ++++++--
 jupyter_client/restarter.py        | 37 ++++++++++++++++++++++--------
 tests/test_kernelmanager.py        |  1 +
 tests/test_restarter.py            |  8 +++----
 5 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/jupyter_client/ioloop/restarter.py b/jupyter_client/ioloop/restarter.py
index d0c70396..6a3b9fe2 100644
--- a/jupyter_client/ioloop/restarter.py
+++ b/jupyter_client/ioloop/restarter.py
@@ -55,9 +55,9 @@ async def poll(self):
         """Poll the kernel."""
         if self.debug:
             self.log.debug("Polling kernel...")
-        is_alive = await self.kernel_manager.is_alive()
+        exit_status = await self.kernel_manager.exit_status()
         now = time.time()
-        if not is_alive:
+        if exit_status is not None:
             self._last_dead = now
             if self._restarting:
                 self._restart_count += 1
@@ -66,7 +66,7 @@ async def poll(self):
 
             if self._restart_count > self.restart_limit:
                 self.log.warning("AsyncIOLoopKernelRestarter: restart failed")
-                self._fire_callbacks("dead")
+                self._fire_callbacks("dead", exit_status)
                 self._restarting = False
                 self._restart_count = 0
                 self.stop()
@@ -78,7 +78,7 @@ async def poll(self):
                     self.restart_limit,
                     "new" if newports else "keep",
                 )
-                self._fire_callbacks("restart")
+                self._fire_callbacks("restart", exit_status)
                 await self.kernel_manager.restart_kernel(now=True, newports=newports)
                 self._restarting = True
         else:
diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py
index 7ad2002b..ecd69a6d 100644
--- a/jupyter_client/manager.py
+++ b/jupyter_client/manager.py
@@ -231,11 +231,17 @@ def stop_restarter(self) -> None:
         """Stop the kernel restarter."""
         pass
 
-    def add_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
+    def add_restart_callback(
+        self,
+        callback: t.Callable[[], object] | t.Callable[[int], object],
+        event: str = "restart",
+        *,
+        accepts_exit_code: bool = False,
+    ) -> None:
         """Register a callback to be called when a kernel is restarted"""
         if self._restarter is None:
             return
-        self._restarter.add_callback(callback, event)
+        self._restarter.add_callback(callback, event, accepts_exit_code=accepts_exit_code)
 
     def remove_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
         """Unregister a callback to be called when a kernel is restarted"""
diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py
index 194ba907..26bf8da6 100644
--- a/jupyter_client/restarter.py
+++ b/jupyter_client/restarter.py
@@ -7,7 +7,9 @@
 """
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
+import functools
 import time
+import typing as t
 
 from traitlets import Bool, Dict, Float, Instance, Integer, default
 from traitlets.config.configurable import LoggingConfigurable
@@ -55,7 +57,8 @@ class KernelRestarter(LoggingConfigurable):
     def _default_last_dead(self):
         return time.time()
 
-    callbacks = Dict()
+    # traitlets.Dict is not typed generic
+    callbacks: t.Dict[str, t.List[t.Callable[[int], object]]] = Dict()  # type: ignore
 
     def _callbacks_default(self):
         return {"restart": [], "dead": []}
@@ -70,8 +73,14 @@ def stop(self):
         msg = "Must be implemented in a subclass"
         raise NotImplementedError(msg)
 
-    def add_callback(self, f, event="restart"):
-        """register a callback to fire on a particular event
+    def add_callback(
+        self,
+        f: t.Callable[[], object] | t.Callable[[int], object],
+        event: str = "restart",
+        *,
+        accepts_exit_code: bool = False,
+    ) -> None:
+        """register a callback to fire on a particular event. If ``accepts_exit_code`` is set, the callable will be passed the exit code as reported by `KernelManager.exit_status`
 
         Possible values for event:
 
@@ -79,7 +88,16 @@ def add_callback(self, f, event="restart"):
           'dead': restart has failed, kernel will be left dead.
 
         """
-        self.callbacks[event].append(f)
+        if not accepts_exit_code:
+
+            @functools.wraps(f)
+            def ignore_exit_code(code: int) -> object:
+                return f()  # type: ignore[call-arg]
+
+            f = ignore_exit_code
+            self.callbacks[event].append(f)
+
+        self.callbacks[event].append(f)  # type: ignore[arg-type]
 
     def remove_callback(self, f, event="restart"):
         """unregister a callback to fire on a particular event
@@ -95,11 +113,11 @@ def remove_callback(self, f, event="restart"):
         except ValueError:
             pass
 
-    def _fire_callbacks(self, event):
+    def _fire_callbacks(self, event, status):
         """fire our callbacks for a particular event"""
         for callback in self.callbacks[event]:
             try:
-                callback()
+                callback(status)
             except Exception:
                 self.log.error(
                     "KernelRestarter: %s callback %r failed",
@@ -115,7 +133,8 @@ def poll(self):
             self.log.debug("Kernel shutdown in progress...")
             return
         now = time.time()
-        if not self.kernel_manager.is_alive():
+        status = self.kernel_manager.exit_status()
+        if status is not None:
             self._last_dead = now
             if self._restarting:
                 self._restart_count += 1
@@ -124,7 +143,7 @@ def poll(self):
 
             if self._restart_count > self.restart_limit:
                 self.log.warning("KernelRestarter: restart failed")
-                self._fire_callbacks("dead")
+                self._fire_callbacks("dead", status)
                 self._restarting = False
                 self._restart_count = 0
                 self.stop()
@@ -136,7 +155,7 @@ def poll(self):
                     self.restart_limit,
                     "new" if newports else "keep",
                 )
-                self._fire_callbacks("restart")
+                self._fire_callbacks("restart", status)
                 self.kernel_manager.restart_kernel(now=True, newports=newports)
                 self._restarting = True
         else:
diff --git a/tests/test_kernelmanager.py b/tests/test_kernelmanager.py
index 65b787a1..986bdda1 100644
--- a/tests/test_kernelmanager.py
+++ b/tests/test_kernelmanager.py
@@ -161,6 +161,7 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected):
 
 
 class TestKernelManagerExitStatus:
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
     @pytest.mark.parametrize('_signal', [signal.SIGILL, signal.SIGSEGV, signal.SIGTERM])
     async def test_exit_status(self, _signal):
         # install kernel
diff --git a/tests/test_restarter.py b/tests/test_restarter.py
index b216842f..bd428162 100644
--- a/tests/test_restarter.py
+++ b/tests/test_restarter.py
@@ -88,7 +88,7 @@ def debug_logging():
 @win_skip
 async def test_restart_check(config, install_kernel, debug_logging):
     """Test that the kernel is restarted and recovers"""
-    # If this test failes, run it with --log-cli-level=DEBUG to inspect
+    # If this test fails, run it with --log-cli-level=DEBUG to inspect
     N_restarts = 1
     config.KernelRestarter.restart_limit = N_restarts
     config.KernelRestarter.debug = True
@@ -144,7 +144,7 @@ def cb():
 @win_skip
 async def test_restarter_gives_up(config, install_fail_kernel, debug_logging):
     """Test that the restarter gives up after reaching the restart limit"""
-    # If this test failes, run it with --log-cli-level=DEBUG to inspect
+    # If this test fails, run it with --log-cli-level=DEBUG to inspect
     N_restarts = 1
     config.KernelRestarter.restart_limit = N_restarts
     config.KernelRestarter.debug = True
@@ -188,7 +188,7 @@ def on_death():
 
 async def test_async_restart_check(config, install_kernel, debug_logging):
     """Test that the kernel is restarted and recovers"""
-    # If this test failes, run it with --log-cli-level=DEBUG to inspect
+    # If this test fails, run it with --log-cli-level=DEBUG to inspect
     N_restarts = 1
     config.KernelRestarter.restart_limit = N_restarts
     config.KernelRestarter.debug = True
@@ -243,7 +243,7 @@ def cb():
 
 async def test_async_restarter_gives_up(config, install_slow_fail_kernel, debug_logging):
     """Test that the restarter gives up after reaching the restart limit"""
-    # If this test failes, run it with --log-cli-level=DEBUG to inspect
+    # If this test fails, run it with --log-cli-level=DEBUG to inspect
     N_restarts = 2
     config.KernelRestarter.restart_limit = N_restarts
     config.KernelRestarter.debug = True

From 5f7214c54ea7587588f19f16ae04b6a287f3698f Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Thu, 17 Aug 2023 13:05:28 +0200
Subject: [PATCH 4/6] the wrapper seems to have broken tests, so doing a
 slightly messier implementation instead

---
 jupyter_client/restarter.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py
index 26bf8da6..a0564879 100644
--- a/jupyter_client/restarter.py
+++ b/jupyter_client/restarter.py
@@ -7,7 +7,6 @@
 """
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-import functools
 import time
 import typing as t
 
@@ -58,7 +57,7 @@ def _default_last_dead(self):
         return time.time()
 
     # traitlets.Dict is not typed generic
-    callbacks: t.Dict[str, t.List[t.Callable[[int], object]]] = Dict()  # type: ignore
+    callbacks: t.Dict[str, t.List[t.Tuple[t.Callable[[int], object], t.Literal[True]] | t.Tuple[t.Callable[[], object], t.Literal[False]]]] = Dict()  # type: ignore[assignment]
 
     def _callbacks_default(self):
         return {"restart": [], "dead": []}
@@ -88,16 +87,8 @@ def add_callback(
           'dead': restart has failed, kernel will be left dead.
 
         """
-        if not accepts_exit_code:
-
-            @functools.wraps(f)
-            def ignore_exit_code(code: int) -> object:
-                return f()  # type: ignore[call-arg]
-
-            f = ignore_exit_code
-            self.callbacks[event].append(f)
-
-        self.callbacks[event].append(f)  # type: ignore[arg-type]
+        # no dynamic validation that the callable is valid in accordance to accepts_exit_code
+        self.callbacks[event].append((f, accepts_exit_code))  # type: ignore[arg-type]
 
     def remove_callback(self, f, event="restart"):
         """unregister a callback to fire on a particular event
@@ -115,14 +106,18 @@ def remove_callback(self, f, event="restart"):
 
     def _fire_callbacks(self, event, status):
         """fire our callbacks for a particular event"""
+        # unpacking in the loop breaks the connection between the variables for mypy
         for callback in self.callbacks[event]:
             try:
-                callback(status)
+                if callback[1] is True:
+                    callback[0](status)
+                else:
+                    callback[0]()
             except Exception:
                 self.log.error(
                     "KernelRestarter: %s callback %r failed",
                     event,
-                    callback,
+                    callback[0],
                     exc_info=True,
                 )
 

From 25586f59d63176a7ce8ca89df7c6c57179dbbed8 Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Thu, 17 Aug 2023 13:15:33 +0200
Subject: [PATCH 5/6] add overloads for add_callback

---
 jupyter_client/restarter.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py
index a0564879..fcdcc411 100644
--- a/jupyter_client/restarter.py
+++ b/jupyter_client/restarter.py
@@ -72,6 +72,26 @@ def stop(self):
         msg = "Must be implemented in a subclass"
         raise NotImplementedError(msg)
 
+    @t.overload
+    def add_callback(
+        self,
+        f: t.Callable[[int], object],
+        event: str = "restart",
+        *,
+        accepts_exit_code: t.Literal[True],
+    ) -> None:
+        ...
+
+    @t.overload
+    def add_callback(
+        self,
+        f: t.Callable[[], object],
+        event: str = "restart",
+        *,
+        accepts_exit_code: t.Literal[False] = False,
+    ) -> None:
+        ...
+
     def add_callback(
         self,
         f: t.Callable[[], object] | t.Callable[[int], object],
@@ -87,7 +107,7 @@ def add_callback(
           'dead': restart has failed, kernel will be left dead.
 
         """
-        # no dynamic validation that the callable is valid in accordance to accepts_exit_code
+        # the type correlation from overloads is not tracked to here by mypy
         self.callbacks[event].append((f, accepts_exit_code))  # type: ignore[arg-type]
 
     def remove_callback(self, f, event="restart"):

From 117b66305e0529ce0d9278ad79130e3db2d30ed3 Mon Sep 17 00:00:00 2001
From: jakkdl <h6+github@pm.me>
Date: Thu, 17 Aug 2023 13:18:56 +0200
Subject: [PATCH 6/6] add test_restart_check_exit_status

---
 tests/test_restarter.py | 62 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 60 insertions(+), 2 deletions(-)

diff --git a/tests/test_restarter.py b/tests/test_restarter.py
index bd428162..078a300e 100644
--- a/tests/test_restarter.py
+++ b/tests/test_restarter.py
@@ -4,7 +4,9 @@
 import asyncio
 import json
 import os
+import signal
 import sys
+import typing as t
 from concurrent.futures import Future
 
 import pytest
@@ -95,9 +97,9 @@ async def test_restart_check(config, install_kernel, debug_logging):
     km = IOLoopKernelManager(kernel_name=install_kernel, config=config)
 
     cbs = 0
-    restarts: list = [Future() for i in range(N_restarts)]
+    restarts: t.List[Future[bool]] = [Future() for i in range(N_restarts)]
 
-    def cb():
+    def cb() -> None:
         nonlocal cbs
         if cbs >= N_restarts:
             raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
@@ -141,6 +143,62 @@ def cb():
         assert km.context.closed
 
 
+@win_skip
+async def test_restart_check_exit_status(config, install_kernel, debug_logging):
+    """Test that the kernel is restarted and recovers, and validates the exit code."""
+    # If this test fails, run it with --log-cli-level=DEBUG to inspect
+    N_restarts = 1
+    config.KernelRestarter.restart_limit = N_restarts
+    config.KernelRestarter.debug = True
+    km = IOLoopKernelManager(kernel_name=install_kernel, config=config)
+
+    cbs = 0
+    restarts: t.List[Future[int]] = [Future() for i in range(N_restarts)]
+
+    def cb(exit_status: int) -> None:
+        nonlocal cbs
+        if cbs >= N_restarts:
+            raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
+        restarts[cbs].set_result(exit_status)
+        cbs += 1
+
+    try:
+        km.start_kernel()
+        km.add_restart_callback(cb, 'restart', accepts_exit_code=True)
+    except BaseException:
+        if km.has_kernel:
+            km.shutdown_kernel()
+        raise
+
+    try:
+        for i in range(N_restarts + 1):
+            kc = km.client()
+            kc.start_channels()
+            kc.wait_for_ready(timeout=60)
+            kc.stop_channels()
+            if i < N_restarts:
+                # Kill without cleanup to simulate crash:
+                assert km.provisioner is not None
+                await km.provisioner.kill()
+                assert restarts[i].result() == -signal.SIGKILL
+                # Wait for kill + restart
+                max_wait = 10.0
+                waited = 0.0
+                while waited < max_wait and km.is_alive():
+                    await asyncio.sleep(0.1)
+                    waited += 0.1
+                while waited < max_wait and not km.is_alive():
+                    await asyncio.sleep(0.1)
+                    waited += 0.1
+
+        assert cbs == N_restarts
+        assert km.is_alive()
+
+    finally:
+        km.shutdown_kernel(now=True)
+        assert km.context.closed
+
+
 @win_skip
 async def test_restarter_gives_up(config, install_fail_kernel, debug_logging):
     """Test that the restarter gives up after reaching the restart limit"""