Skip to content

Commit 479a6ea

Browse files
committed
feat: expose coroutine constructor
1 parent 3a53e6d commit 479a6ea

File tree

7 files changed

+124
-66
lines changed

7 files changed

+124
-66
lines changed

guide/src/async-await.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,24 @@ To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_
9696

9797
Each `coroutine.send` call is translated to a `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;
9898

99-
*The type does not yet have a public constructor until the design is finalized.*
99+
Coroutine can also be instantiated directly
100+
101+
```rust
102+
# # ![allow(dead_code)]
103+
use pyo3::prelude::*;
104+
use pyo3::coroutine::{CancelHandle, Coroutine};
105+
106+
#[pyfunction]
107+
fn new_coroutine(py: Python<'_>) -> Coroutine {
108+
let mut cancel = CancelHandle::new();
109+
let throw_callback = cancel.throw_callback();
110+
let future = async move {
111+
cancel.cancelled().await;
112+
PyResult::Ok(())
113+
};
114+
Coroutine::new(pyo3::intern!(py, "my_coro"), future)
115+
.with_qualname_prefix("MyClass")
116+
.with_throw_callback(throw_callback)
117+
.with_allow_threads(true)
118+
}
119+
```

newsfragments/3613.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Expose `Coroutine` constructor

pyo3-macros-backend/src/method.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,13 +503,13 @@ impl<'a> FnSpec<'a> {
503503
};
504504
let mut call = quote! {{
505505
let future = #future;
506-
_pyo3::impl_::coroutine::new_coroutine(
506+
_pyo3::coroutine::Coroutine::new(
507507
_pyo3::intern!(py, stringify!(#python_name)),
508-
#qualname_prefix,
509-
#throw_callback,
510-
#allow_threads,
511508
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
512509
)
510+
.with_qualname_prefix(#qualname_prefix)
511+
.with_throw_callback(#throw_callback)
512+
.with_allow_threads(#allow_threads)
513513
}};
514514
if cancel_handle.is_some() {
515515
call = quote! {{

src/coroutine.rs

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use std::{
1111
use pyo3_macros::{pyclass, pymethods};
1212

1313
use crate::{
14-
coroutine::{cancel::ThrowCallback, waker::CoroutineWaker},
15-
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
14+
coroutine::waker::CoroutineWaker,
15+
exceptions::{PyRuntimeError, PyStopIteration},
1616
panic::PanicException,
1717
pyclass::IterNextOutput,
1818
types::PyString,
@@ -27,7 +27,7 @@ pub(crate) mod cancel;
2727
mod trio;
2828
pub(crate) mod waker;
2929

30-
pub use cancel::CancelHandle;
30+
pub use cancel::{CancelHandle, ThrowCallback};
3131

3232
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
3333

@@ -64,11 +64,11 @@ where
6464
/// Python coroutine wrapping a [`Future`].
6565
#[pyclass(crate = "crate")]
6666
pub struct Coroutine {
67-
name: Option<Py<PyString>>,
67+
future: Option<Pin<Box<dyn CoroutineFuture + Send>>>,
68+
name: Py<PyString>,
6869
qualname_prefix: Option<&'static str>,
6970
throw_callback: Option<ThrowCallback>,
7071
allow_threads: bool,
71-
future: Option<Pin<Box<dyn CoroutineFuture + Send>>>,
7272
waker: Option<Arc<CoroutineWaker>>,
7373
}
7474

@@ -78,29 +78,44 @@ impl Coroutine {
7878
/// Coroutine `send` polls the wrapped future, ignoring the value passed
7979
/// (should always be `None` anyway).
8080
///
81-
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed
82-
pub(crate) fn new<F, T, E>(
83-
name: Option<Py<PyString>>,
84-
qualname_prefix: Option<&'static str>,
85-
throw_callback: Option<ThrowCallback>,
86-
allow_threads: bool,
87-
future: F,
88-
) -> Self
81+
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed.
82+
pub fn new<F, T, E>(name: impl Into<Py<PyString>>, future: F) -> Self
8983
where
9084
F: Future<Output = Result<T, E>> + Send + 'static,
9185
T: IntoPy<PyObject> + Send,
9286
E: Into<PyErr> + Send,
9387
{
9488
Self {
95-
name,
96-
qualname_prefix,
97-
throw_callback,
98-
allow_threads,
9989
future: Some(Box::pin(future)),
90+
name: name.into(),
91+
qualname_prefix: None,
92+
throw_callback: None,
93+
allow_threads: false,
10094
waker: None,
10195
}
10296
}
10397

98+
/// Set a prefix for `__qualname__`, which will be joined with a "."
99+
pub fn with_qualname_prefix(mut self, prefix: impl Into<Option<&'static str>>) -> Self {
100+
self.qualname_prefix = prefix.into();
101+
self
102+
}
103+
104+
/// Register a callback for coroutine `throw` method.
105+
///
106+
/// The exception passed to `throw` is then redirected to this callback, notifying the
107+
/// associated [`CancelHandle`], without being reraised.
108+
pub fn with_throw_callback(mut self, callback: impl Into<Option<ThrowCallback>>) -> Self {
109+
self.throw_callback = callback.into();
110+
self
111+
}
112+
113+
/// Release the GIL while polling the future wrapped.
114+
pub fn with_allow_threads(mut self, allow_threads: bool) -> Self {
115+
self.allow_threads = allow_threads;
116+
self
117+
}
118+
104119
fn poll_inner(
105120
&mut self,
106121
py: Python<'_>,
@@ -169,22 +184,18 @@ pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResul
169184
#[pymethods(crate = "crate")]
170185
impl Coroutine {
171186
#[getter]
172-
fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
173-
match &self.name {
174-
Some(name) => Ok(name.clone_ref(py)),
175-
None => Err(PyAttributeError::new_err("__name__")),
176-
}
187+
fn __name__(&self, py: Python<'_>) -> Py<PyString> {
188+
self.name.clone_ref(py)
177189
}
178190

179191
#[getter]
180192
fn __qualname__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
181-
match (&self.name, &self.qualname_prefix) {
182-
(Some(name), Some(prefix)) => Ok(format!("{}.{}", prefix, name.as_ref(py).to_str()?)
193+
Ok(match &self.qualname_prefix {
194+
Some(prefix) => format!("{}.{}", prefix, self.name.as_ref(py).to_str()?)
183195
.as_str()
184-
.into_py(py)),
185-
(Some(name), None) => Ok(name.clone_ref(py)),
186-
(None, _) => Err(PyAttributeError::new_err("__qualname__")),
187-
}
196+
.into_py(py),
197+
None => self.name.clone_ref(py),
198+
})
188199
}
189200

190201
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {

src/coroutine/cancel.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl CancelHandle {
5252
Cancelled(self).await
5353
}
5454

55-
#[doc(hidden)]
55+
/// Instantiate a [`ThrowCallback`] associated to this cancel handle.
5656
pub fn throw_callback(&self) -> ThrowCallback {
5757
ThrowCallback(self.0.clone())
5858
}
@@ -68,7 +68,7 @@ impl Future for Cancelled<'_> {
6868
}
6969
}
7070

71-
#[doc(hidden)]
71+
/// Callback for coroutine `throw` method, notifying the associated [`CancelHandle`]
7272
pub struct ThrowCallback(Arc<Mutex<Inner>>);
7373

7474
impl ThrowCallback {

src/impl_/coroutine.rs

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,9 @@
11
use std::{
2-
future::Future,
32
mem,
43
ops::{Deref, DerefMut},
54
};
65

7-
use crate::{
8-
coroutine::{cancel::ThrowCallback, Coroutine},
9-
pyclass::boolean_struct::False,
10-
types::PyString,
11-
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, Python,
12-
};
13-
14-
pub fn new_coroutine<F, T, E>(
15-
name: &PyString,
16-
qualname_prefix: Option<&'static str>,
17-
throw_callback: Option<ThrowCallback>,
18-
allow_threads: bool,
19-
future: F,
20-
) -> Coroutine
21-
where
22-
F: Future<Output = Result<T, E>> + Send + 'static,
23-
T: IntoPy<PyObject> + Send,
24-
E: Into<PyErr> + Send,
25-
{
26-
Coroutine::new(
27-
Some(name.into()),
28-
qualname_prefix,
29-
throw_callback,
30-
allow_threads,
31-
future,
32-
)
33-
}
6+
use crate::{pyclass::boolean_struct::False, Py, PyAny, PyCell, PyClass, PyResult, Python};
347

358
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
369
// SAFETY: Py<T> can be casted as *const PyCell<T>

tests/test_coroutine.rs

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#![cfg(feature = "macros")]
22
#![cfg(not(target_arch = "wasm32"))]
3-
use std::{ops::Deref, task::Poll, thread, time::Duration};
3+
use std::{ops::Deref, sync::Arc, task::Poll, thread, time::Duration};
44

55
use futures::{channel::oneshot, future::poll_fn, FutureExt};
66
use pyo3::{
7-
coroutine::CancelHandle,
7+
coroutine::{CancelHandle, Coroutine},
8+
intern,
89
prelude::*,
910
py_run,
11+
sync::GILOnceCell,
1012
types::{IntoPyDict, PyType},
1113
};
1214

@@ -235,7 +237,7 @@ fn test_async_method_receiver() {
235237
Python::with_gil(|gil| {
236238
let test = r#"
237239
import asyncio
238-
240+
239241
obj = Counter()
240242
coro1 = obj.get()
241243
coro2 = obj.get()
@@ -266,3 +268,54 @@ fn test_async_method_receiver() {
266268
py_run!(gil, *locals, &common::asyncio_windows(test));
267269
})
268270
}
271+
272+
#[test]
273+
fn multi_thread_event_loop() {
274+
Python::with_gil(|gil| {
275+
let sleep = wrap_pyfunction!(sleep, gil).unwrap();
276+
let test = r#"
277+
import asyncio
278+
import threading
279+
loop = asyncio.new_event_loop()
280+
# spawn the sleep task and run just one iteration of the event loop
281+
# to schedule the sleep wakeup
282+
task = loop.create_task(sleep(0.1))
283+
loop.stop()
284+
loop.run_forever()
285+
assert not task.done()
286+
# spawn a thread to complete the execution of the sleep task
287+
def target(loop, task):
288+
loop.run_until_complete(task)
289+
thread = threading.Thread(target=target, args=(loop, task))
290+
thread.start()
291+
thread.join()
292+
assert task.result() == 42
293+
"#;
294+
py_run!(gil, sleep, &common::asyncio_windows(test));
295+
})
296+
}
297+
298+
#[test]
299+
fn closed_event_loop() {
300+
let waker = Arc::new(GILOnceCell::new());
301+
let waker2 = waker.clone();
302+
let future = poll_fn(move |cx| {
303+
Python::with_gil(|gil| waker2.set(gil, cx.waker().clone()).unwrap());
304+
Poll::Pending::<PyResult<()>>
305+
});
306+
Python::with_gil(|gil| {
307+
let register_waker = Coroutine::new(intern!(gil, "register_waker"), future).into_py(gil);
308+
let test = r#"
309+
import asyncio
310+
loop = asyncio.new_event_loop()
311+
# register a waker by spawning a task and polling it once, then close the loop
312+
task = loop.create_task(register_waker)
313+
loop.stop()
314+
loop.run_forever()
315+
loop.close()
316+
"#;
317+
py_run!(gil, register_waker, &common::asyncio_windows(test));
318+
// asyncio waker can be used even if the event loop is closed
319+
Python::with_gil(|gil| waker.get(gil).unwrap().wake_by_ref())
320+
})
321+
}

0 commit comments

Comments
 (0)