Skip to content

Commit ceb7058

Browse files
Prevent Runtime use over forks (#1208)
* Raise an assertion error when a Runtime is used by client/worker creation/usage. * Add _RuntimeRef to encapsulate default runtime creation. Add Runtime.prevent_default to allow users to more easily enforce that a default runtime is never automatically created * remove blank line to fix linter error * fix use of Self since it's not avaiable in typing until 3.11 * remove references to ForkContext to avoid exploding in Windows * switch type of fixture to Iterator instead of Generator * run formatter * except the correct error type to prevent breakage on windows * Update tests to match error info. Update prevent_default test to demonstrate that you can call prevent_default and then set_default to allow future calls to default. add tests for set_default * fix typo in docstring * remove empty return in Runtime.set_default. Remove _default_created flag in favor of using the optional nature of _default_runtime in _RuntimeRef
1 parent 11e8650 commit ceb7058

File tree

11 files changed

+323
-14
lines changed

11 files changed

+323
-14
lines changed

scripts/gen_bridge_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str:
171171
py: Python<'p>,
172172
call: RpcCall,
173173
) -> PyResult<Bound<'p, PyAny>> {
174+
self.runtime.assert_same_process("use client")?;
174175
use temporal_client::${descriptor_name};
175176
let mut retry_client = self.retry_client.clone();
176177
self.runtime.future_into_py(py, async move {

temporalio/bridge/src/client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub fn connect_client<'a>(
9292
config: ClientConfig,
9393
) -> PyResult<Bound<'a, PyAny>> {
9494
let opts: ClientOptions = config.try_into()?;
95+
runtime_ref.runtime.assert_same_process("create client")?;
9596
let runtime = runtime_ref.runtime.clone();
9697
runtime_ref.runtime.future_into_py(py, async move {
9798
Ok(ClientRef {

temporalio/bridge/src/client_rpc_generated.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ impl ClientRef {
1515
py: Python<'p>,
1616
call: RpcCall,
1717
) -> PyResult<Bound<'p, PyAny>> {
18+
self.runtime.assert_same_process("use client")?;
1819
use temporal_client::WorkflowService;
1920
let mut retry_client = self.retry_client.clone();
2021
self.runtime.future_into_py(py, async move {
@@ -566,6 +567,7 @@ impl ClientRef {
566567
py: Python<'p>,
567568
call: RpcCall,
568569
) -> PyResult<Bound<'p, PyAny>> {
570+
self.runtime.assert_same_process("use client")?;
569571
use temporal_client::OperatorService;
570572
let mut retry_client = self.retry_client.clone();
571573
self.runtime.future_into_py(py, async move {
@@ -628,6 +630,7 @@ impl ClientRef {
628630
}
629631

630632
fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
633+
self.runtime.assert_same_process("use client")?;
631634
use temporal_client::CloudService;
632635
let mut retry_client = self.retry_client.clone();
633636
self.runtime.future_into_py(py, async move {
@@ -842,6 +845,7 @@ impl ClientRef {
842845
}
843846

844847
fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
848+
self.runtime.assert_same_process("use client")?;
845849
use temporal_client::TestService;
846850
let mut retry_client = self.retry_client.clone();
847851
self.runtime.future_into_py(py, async move {
@@ -881,6 +885,7 @@ impl ClientRef {
881885
}
882886

883887
fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
888+
self.runtime.assert_same_process("use client")?;
884889
use temporal_client::HealthService;
885890
let mut retry_client = self.retry_client.clone();
886891
self.runtime.future_into_py(py, async move {

temporalio/bridge/src/runtime.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use futures::channel::mpsc::Receiver;
2-
use pyo3::exceptions::{PyRuntimeError, PyValueError};
2+
use pyo3::exceptions::{PyAssertionError, PyRuntimeError, PyValueError};
33
use pyo3::prelude::*;
44
use pythonize::pythonize;
55
use std::collections::HashMap;
@@ -33,6 +33,7 @@ pub struct RuntimeRef {
3333

3434
#[derive(Clone)]
3535
pub(crate) struct Runtime {
36+
pub(crate) pid: u32,
3637
pub(crate) core: Arc<CoreRuntime>,
3738
metrics_call_buffer: Option<Arc<MetricsCallBuffer<BufferedMetricRef>>>,
3839
log_forwarder_handle: Option<Arc<JoinHandle<()>>>,
@@ -173,6 +174,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
173174

174175
Ok(RuntimeRef {
175176
runtime: Runtime {
177+
pid: std::process::id(),
176178
core: Arc::new(core),
177179
metrics_call_buffer,
178180
log_forwarder_handle,
@@ -197,6 +199,18 @@ impl Runtime {
197199
let _guard = self.core.tokio_handle().enter();
198200
pyo3_async_runtimes::generic::future_into_py::<TokioRuntime, _, T>(py, fut)
199201
}
202+
203+
pub(crate) fn assert_same_process(&self, action: &'static str) -> PyResult<()> {
204+
let current_pid = std::process::id();
205+
if self.pid != current_pid {
206+
Err(PyAssertionError::new_err(format!(
207+
"Cannot {} across forks (original runtime PID is {}, current is {})",
208+
action, self.pid, current_pid,
209+
)))
210+
} else {
211+
Ok(())
212+
}
213+
}
200214
}
201215

202216
impl Drop for Runtime {

temporalio/bridge/src/worker.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ pub fn new_worker(
474474
config: WorkerConfig,
475475
) -> PyResult<WorkerRef> {
476476
enter_sync!(runtime_ref.runtime);
477+
runtime_ref.runtime.assert_same_process("create worker")?;
477478
let event_loop_task_locals = Arc::new(OnceLock::new());
478479
let config = convert_worker_config(config, event_loop_task_locals.clone())?;
479480
let worker = temporal_sdk_core::init_worker(
@@ -495,6 +496,9 @@ pub fn new_replay_worker<'a>(
495496
config: WorkerConfig,
496497
) -> PyResult<Bound<'a, PyTuple>> {
497498
enter_sync!(runtime_ref.runtime);
499+
runtime_ref
500+
.runtime
501+
.assert_same_process("create replay worker")?;
498502
let event_loop_task_locals = Arc::new(OnceLock::new());
499503
let config = convert_worker_config(config, event_loop_task_locals.clone())?;
500504
let (history_pusher, stream) = HistoryPusher::new(runtime_ref.runtime.clone());
@@ -519,6 +523,7 @@ pub fn new_replay_worker<'a>(
519523
#[pymethods]
520524
impl WorkerRef {
521525
fn validate<'p>(&self, py: Python<'p>) -> PyResult<Bound<PyAny, 'p>> {
526+
self.runtime.assert_same_process("use worker")?;
522527
let worker = self.worker.as_ref().unwrap().clone();
523528
// Set custom slot supplier task locals so they can run futures.
524529
// Event loop is assumed to be running at this point.
@@ -538,6 +543,7 @@ impl WorkerRef {
538543
}
539544

540545
fn poll_workflow_activation<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
546+
self.runtime.assert_same_process("use worker")?;
541547
let worker = self.worker.as_ref().unwrap().clone();
542548
self.runtime.future_into_py(py, async move {
543549
let bytes = match worker.poll_workflow_activation().await {
@@ -550,6 +556,7 @@ impl WorkerRef {
550556
}
551557

552558
fn poll_activity_task<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
559+
self.runtime.assert_same_process("use worker")?;
553560
let worker = self.worker.as_ref().unwrap().clone();
554561
self.runtime.future_into_py(py, async move {
555562
let bytes = match worker.poll_activity_task().await {
@@ -562,6 +569,7 @@ impl WorkerRef {
562569
}
563570

564571
fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
572+
self.runtime.assert_same_process("use worker")?;
565573
let worker = self.worker.as_ref().unwrap().clone();
566574
self.runtime.future_into_py(py, async move {
567575
let bytes = match worker.poll_nexus_task().await {

temporalio/runtime.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,82 @@
2424
import temporalio.bridge.runtime
2525
import temporalio.common
2626

27-
_default_runtime: Optional[Runtime] = None
27+
28+
class _RuntimeRef:
29+
def __init__(
30+
self,
31+
) -> None:
32+
self._default_runtime: Runtime | None = None
33+
self._prevent_default = False
34+
35+
def default(self) -> Runtime:
36+
if not self._default_runtime:
37+
if self._prevent_default:
38+
raise RuntimeError(
39+
"Cannot create default Runtime after Runtime.prevent_default has been called"
40+
)
41+
self._default_runtime = Runtime(telemetry=TelemetryConfig())
42+
self._default_created = True
43+
return self._default_runtime
44+
45+
def prevent_default(self):
46+
if self._default_runtime:
47+
raise RuntimeError(
48+
"Runtime.prevent_default called after default runtime has been created or set"
49+
)
50+
self._prevent_default = True
51+
52+
def set_default(
53+
self, runtime: Runtime, *, error_if_already_set: bool = True
54+
) -> None:
55+
if self._default_runtime and error_if_already_set:
56+
raise RuntimeError("Runtime default already set")
57+
58+
self._default_runtime = runtime
59+
60+
61+
_runtime_ref: _RuntimeRef = _RuntimeRef()
2862

2963

3064
class Runtime:
3165
"""Runtime for Temporal Python SDK.
3266
33-
Users are encouraged to use :py:meth:`default`. It can be set with
67+
Most users are encouraged to use :py:meth:`default`. It can be set with
3468
:py:meth:`set_default`. Every time a new runtime is created, a new internal
3569
thread pool is created.
3670
37-
Runtimes do not work across forks.
71+
Runtimes do not work across forks. Advanced users should consider using
72+
:py:meth:`prevent_default` and `:py:meth`set_default` to ensure each
73+
fork creates it's own runtime.
74+
3875
"""
3976

4077
@classmethod
4178
def default(cls) -> Runtime:
42-
"""Get the default runtime, creating if not already created.
79+
"""Get the default runtime, creating if not already created. If :py:meth:`prevent_default`
80+
is called before this method it will raise a RuntimeError instead of creating a default
81+
runtime.
4382
4483
If the default runtime needs to be different, it should be done with
4584
:py:meth:`set_default` before this is called or ever used.
4685
4786
Returns:
4887
The default runtime.
4988
"""
50-
global _default_runtime
51-
if not _default_runtime:
52-
_default_runtime = cls(telemetry=TelemetryConfig())
53-
return _default_runtime
89+
global _runtime_ref
90+
return _runtime_ref.default()
91+
92+
@classmethod
93+
def prevent_default(cls):
94+
"""Prevent :py:meth:`default` from lazily creating a :py:class:`Runtime`.
95+
96+
Raises a RuntimeError if a default :py:class:`Runtime` has already been created.
97+
98+
Explicitly setting a default runtime with :py:meth:`set_default` bypasses this setting and
99+
future calls to :py:meth:`default` will return the provided runtime.
100+
"""
101+
global _runtime_ref
102+
_runtime_ref.prevent_default()
54103

55104
@staticmethod
56105
def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
@@ -65,10 +114,8 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
65114
error_if_already_set: If True and default is already set, this will
66115
raise a RuntimeError.
67116
"""
68-
global _default_runtime
69-
if _default_runtime and error_if_already_set:
70-
raise RuntimeError("Runtime default already set")
71-
_default_runtime = runtime
117+
global _runtime_ref
118+
_runtime_ref.set_default(runtime, error_if_already_set=error_if_already_set)
72119

73120
def __init__(self, *, telemetry: TelemetryConfig) -> None:
74121
"""Create a default runtime with the given telemetry config.

tests/conftest.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2+
import multiprocessing.context
23
import os
34
import sys
4-
from typing import AsyncGenerator
5+
from typing import AsyncGenerator, Iterator
56

67
import pytest
78
import pytest_asyncio
@@ -133,6 +134,23 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
133134
await env.shutdown()
134135

135136

137+
@pytest.fixture(scope="session")
138+
def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]:
139+
mp_ctx = None
140+
try:
141+
mp_ctx = multiprocessing.get_context("fork")
142+
except ValueError:
143+
pass
144+
145+
try:
146+
yield mp_ctx
147+
finally:
148+
if mp_ctx:
149+
for p in mp_ctx.active_children():
150+
p.terminate()
151+
p.join()
152+
153+
136154
@pytest_asyncio.fixture
137155
async def client(env: WorkflowEnvironment) -> Client:
138156
return env.client

tests/helpers/fork.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import multiprocessing
5+
import multiprocessing.context
6+
import sys
7+
from dataclasses import dataclass
8+
from typing import Any
9+
10+
import pytest
11+
12+
13+
@dataclass
14+
class _ForkTestResult:
15+
status: str
16+
err_name: str | None
17+
err_msg: str | None
18+
19+
def __eq__(self, value: object) -> bool:
20+
if not isinstance(value, _ForkTestResult):
21+
return False
22+
23+
valid_err_msg = False
24+
25+
if self.err_msg and value.err_msg:
26+
valid_err_msg = (
27+
self.err_msg in value.err_msg or value.err_msg in self.err_msg
28+
)
29+
30+
return (
31+
value.status == self.status
32+
and value.err_name == value.err_name
33+
and valid_err_msg
34+
)
35+
36+
@staticmethod
37+
def assertion_error(message: str) -> _ForkTestResult:
38+
return _ForkTestResult(
39+
status="error", err_name="AssertionError", err_msg=message
40+
)
41+
42+
43+
class _TestFork:
44+
_expected: _ForkTestResult
45+
46+
async def coro(self) -> Any:
47+
raise NotImplementedError()
48+
49+
def entry(self):
50+
event_loop = asyncio.new_event_loop()
51+
asyncio.set_event_loop(event_loop)
52+
try:
53+
event_loop.run_until_complete(self.coro())
54+
payload = _ForkTestResult(status="ok", err_name=None, err_msg=None)
55+
except BaseException as err:
56+
payload = _ForkTestResult(
57+
status="error", err_name=err.__class__.__name__, err_msg=str(err)
58+
)
59+
60+
self._child_conn.send(payload)
61+
self._child_conn.close()
62+
63+
def run(self, mp_fork_context: multiprocessing.context.BaseContext | None):
64+
process_factory = getattr(mp_fork_context, "Process", None)
65+
66+
if not mp_fork_context or not process_factory:
67+
pytest.skip("fork context not available")
68+
69+
self._parent_conn, self._child_conn = mp_fork_context.Pipe(duplex=False)
70+
# start fork
71+
child_process = process_factory(target=self.entry, args=(), daemon=False)
72+
child_process.start()
73+
# close parent's handle on child_conn
74+
self._child_conn.close()
75+
76+
# get run info from pipe
77+
payload = self._parent_conn.recv()
78+
self._parent_conn.close()
79+
80+
assert payload == self._expected

0 commit comments

Comments
 (0)