diff --git a/newsfragments/4987.added.md b/newsfragments/4987.added.md new file mode 100644 index 00000000000..a6f21386ad3 --- /dev/null +++ b/newsfragments/4987.added.md @@ -0,0 +1 @@ +Add `IntoPyObject` & `FromPyObject` for `Arc` diff --git a/src/conversions/std/mod.rs b/src/conversions/std/mod.rs index 305344b1284..fe563dc9196 100644 --- a/src/conversions/std/mod.rs +++ b/src/conversions/std/mod.rs @@ -9,5 +9,6 @@ mod path; mod set; mod slice; mod string; +mod sync; mod time; mod vec; diff --git a/src/conversions/std/sync.rs b/src/conversions/std/sync.rs new file mode 100644 index 00000000000..e2fad9856c5 --- /dev/null +++ b/src/conversions/std/sync.rs @@ -0,0 +1,85 @@ +#[cfg(feature = "experimental-inspect")] +use crate::inspect::types::TypeInfo; +use crate::types::PyAnyMethods; +use crate::{Bound, BoundObject, FromPyObject, IntoPyObject, PyAny, PyErr, PyResult, Python}; +use std::sync::Arc; + +// TODO find a better way (without the extra type parameters) to name the associated types in the trait. +impl<'py, A, T, O, E> IntoPyObject<'py> for Arc +where + for<'a> &'a A: IntoPyObject<'py, Target = T, Output = O, Error = E>, + O: BoundObject<'py, T>, + E: Into, +{ + type Target = T; + type Output = O; + type Error = E; + + #[inline] + fn into_pyobject(self, py: Python<'py>) -> Result { + (&*self).into_pyobject(py) + } + + #[cfg(feature = "experimental-inspect")] + fn type_output() -> TypeInfo { + <&A as IntoPyObject<'py>>::type_output() + } +} + +impl<'a, 'py, T: 'a> IntoPyObject<'py> for &'a Arc +where + &'a T: IntoPyObject<'py>, +{ + type Target = <&'a T as IntoPyObject<'py>>::Target; + type Output = <&'a T as IntoPyObject<'py>>::Output; + type Error = <&'a T as IntoPyObject<'py>>::Error; + + #[inline] + fn into_pyobject(self, py: Python<'py>) -> Result { + (&**self).into_pyobject(py) + } + + #[cfg(feature = "experimental-inspect")] + fn type_output() -> TypeInfo { + <&'a T as IntoPyObject<'py>>::type_output() + } +} + +impl<'py, T> FromPyObject<'py> for Arc +where + T: FromPyObject<'py>, +{ + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + ob.extract::().map(Arc::new) + } + + #[cfg(feature = "experimental-inspect")] + fn type_input() -> TypeInfo { + T::type_input() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::PyInt; + use crate::Python; + + #[test] + fn test_arc_into_pyobject() { + macro_rules! test_roundtrip { + ($arc:expr) => { + Python::with_gil(|py| { + let arc = $arc; + let obj: Bound<'_, PyInt> = arc.into_pyobject(py).unwrap(); + assert_eq!(obj.extract::().unwrap(), 42); + let roundtrip = obj.extract::>().unwrap(); + assert_eq!(&42, roundtrip.as_ref()); + }); + }; + } + + test_roundtrip!(Arc::new(42)); + test_roundtrip!(&Arc::new(42)); + } +} diff --git a/tests/test_getter_setter.rs b/tests/test_getter_setter.rs index 82a50442ec5..e2b4570e14f 100644 --- a/tests/test_getter_setter.rs +++ b/tests/test_getter_setter.rs @@ -318,3 +318,24 @@ fn test_optional_setter() { ); }) } + +#[pyclass(get_all)] +struct ArcGetterSetter { + #[pyo3(set)] + foo: std::sync::Arc, +} + +#[test] +fn test_arc_getter_setter() { + Python::with_gil(|py| { + let instance = Py::new( + py, + ArcGetterSetter { + foo: std::sync::Arc::new(42), + }, + ) + .unwrap(); + py_run!(py, instance, "assert instance.foo == 42"); + py_run!(py, instance, "instance.foo = 43; assert instance.foo == 43"); + }) +}