Skip to content

Commit 629c763

Browse files
author
Golf Player
committed
Add basic hooks during execution
This will enable tracking of execution process without subclassing the way papermill does.
1 parent 202e046 commit 629c763

File tree

3 files changed

+62
-13
lines changed

3 files changed

+62
-13
lines changed

.bumpversion.cfg

-8
This file was deleted.

nbclient/client.py

+48-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
CellExecutionComplete,
2929
CellExecutionError
3030
)
31-
from .util import run_sync, ensure_async
31+
from .util import run_sync, ensure_async, run_hook
3232
from .output_widget import OutputWidget
3333

3434

@@ -227,6 +227,45 @@ class NotebookClient(LoggingConfigurable):
227227

228228
kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')
229229

230+
on_execution_start: t.Optional[t.Callable] = Any(
231+
default_value=None,
232+
allow_none=True,
233+
help=dedent("""
234+
Called after the kernel manager and kernel client are setup, and cells
235+
are about to execute.
236+
Called with kwargs `kernel_id`.
237+
"""),
238+
).tag(config=True)
239+
240+
on_cell_start: t.Optional[t.Callable] = Any(
241+
default_value=None,
242+
allow_none=True,
243+
help=dedent("""
244+
A callable which executes before a cell is executed.
245+
Called with kwargs `cell`, and `cell_index`.
246+
"""),
247+
).tag(config=True)
248+
249+
on_cell_complete: t.Optional[t.Callable] = Any(
250+
default_value=None,
251+
allow_none=True,
252+
help=dedent("""
253+
A callable which executes after a cell execution is complete. It is
254+
called even when a cell results in a failure.
255+
Called with kwargs `cell`, and `cell_index`.
256+
"""),
257+
).tag(config=True)
258+
259+
on_cell_error: t.Optional[t.Callable] = Any(
260+
default_value=None,
261+
allow_none=True,
262+
help=dedent("""
263+
A callable which executes when a cell execution results in an error.
264+
This is executed even if errors are suppressed with `cell_allows_errors`.
265+
Called with kwargs `cell`, and `cell_index`.
266+
"""),
267+
).tag(config=True)
268+
230269
@default('kernel_manager_class')
231270
def _kernel_manager_class_default(self) -> KernelManager:
232271
"""Use a dynamic default to avoid importing jupyter_client at startup"""
@@ -412,6 +451,7 @@ async def async_start_new_kernel_client(self, **kwargs) -> t.Tuple[KernelClient,
412451
await self._async_cleanup_kernel()
413452
raise
414453
self.kc.allow_stdin = False
454+
run_hook(self.on_execution_start, kernel_id=kernel_id)
415455
return self.kc, kernel_id
416456

417457
start_new_kernel_client = run_sync(async_start_new_kernel_client)
@@ -702,14 +742,16 @@ def _passed_deadline(self, deadline: int) -> bool:
702742
def _check_raise_for_error(
703743
self,
704744
cell: NotebookNode,
745+
cell_index: int,
705746
exec_reply: t.Optional[t.Dict]) -> None:
706747

707748
cell_allows_errors = self.allow_errors or "raises-exception" in cell.metadata.get(
708749
"tags", []
709750
)
710751

711-
if self.force_raise_errors or not cell_allows_errors:
712-
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
752+
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
753+
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
754+
if self.force_raise_errors or not cell_allows_errors:
713755
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
714756

715757
async def async_execute_cell(
@@ -760,13 +802,15 @@ async def async_execute_cell(
760802
cell['metadata']['execution'] = {}
761803

762804
self.log.debug("Executing cell:\n%s", cell.source)
805+
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
763806
parent_msg_id = await ensure_async(
764807
self.kc.execute(
765808
cell.source,
766809
store_history=store_history,
767810
stop_on_error=not self.allow_errors
768811
)
769812
)
813+
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
770814
# We launched a code cell to execute
771815
self.code_cells_executed += 1
772816
exec_timeout = self._get_timeout(cell)
@@ -792,7 +836,7 @@ async def async_execute_cell(
792836

793837
if execution_count:
794838
cell['execution_count'] = execution_count
795-
self._check_raise_for_error(cell, exec_reply)
839+
self._check_raise_for_error(cell, cell_index, exec_reply)
796840
self.nb['cells'][cell_index] = cell
797841
return cell
798842

nbclient/util.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import asyncio
77
import sys
88
import inspect
9-
from typing import Callable, Awaitable, Any, Union
9+
from typing import Callable, Awaitable, Any, Union, Optional
10+
from functools import partial
1011

1112

1213
def check_ipython() -> None:
@@ -91,3 +92,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any:
9192
return result
9293
# obj doesn't need to be awaited
9394
return obj
95+
96+
97+
def run_hook(hook: Optional[Callable], **kwargs) -> None:
98+
if hook is None:
99+
return
100+
if inspect.iscoroutinefunction(hook):
101+
future = hook(**kwargs)
102+
else:
103+
loop = asyncio.get_event_loop()
104+
hook_with_kwargs = partial(hook, **kwargs)
105+
future = loop.run_in_executor(None, hook_with_kwargs)
106+
asyncio.ensure_future(future)

0 commit comments

Comments
 (0)