Skip to content

Introspection: implement output type #5208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/5208.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introspection and sub generation: add basic return type support
13 changes: 11 additions & 2 deletions pyo3-introspection/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ fn convert_members(
arguments,
parent: _,
decorators,
} => functions.push(convert_function(name, arguments, decorators)),
returns,
} => functions.push(convert_function(name, arguments, decorators, returns)),
}
}
Ok((modules, classes, functions))
Expand Down Expand Up @@ -170,7 +171,12 @@ fn convert_class(
})
}

fn convert_function(name: &str, arguments: &ChunkArguments, decorators: &[String]) -> Function {
fn convert_function(
name: &str,
arguments: &ChunkArguments,
decorators: &[String],
returns: &Option<String>,
) -> Function {
Function {
name: name.into(),
decorators: decorators.to_vec(),
Expand All @@ -187,6 +193,7 @@ fn convert_function(name: &str, arguments: &ChunkArguments, decorators: &[String
.as_ref()
.map(convert_variable_length_argument),
},
returns: returns.clone(),
}
}

Expand Down Expand Up @@ -382,6 +389,8 @@ enum Chunk {
parent: Option<String>,
#[serde(default)]
decorators: Vec<String>,
#[serde(default)]
returns: Option<String>,
},
}

Expand Down
2 changes: 2 additions & 0 deletions pyo3-introspection/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub struct Function {
/// decorator like 'property' or 'staticmethod'
pub decorators: Vec<String>,
pub arguments: Arguments,
/// return type
pub returns: Option<String>,
}

#[derive(Debug, Eq, PartialEq, Clone, Hash)]
Expand Down
33 changes: 22 additions & 11 deletions pyo3-introspection/src/stubs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,22 @@ fn function_stubs(function: &Function, modules_to_import: &mut BTreeSet<String>)
if let Some(argument) = &function.arguments.kwarg {
parameters.push(format!("**{}", variable_length_argument_stub(argument)));
}
let output = format!("def {}({}): ...", function.name, parameters.join(", "));
if function.decorators.is_empty() {
return output;
}
let mut buffer = String::new();
for decorator in &function.decorators {
buffer.push('@');
buffer.push_str(decorator);
buffer.push('\n');
}
buffer.push_str(&output);
buffer.push_str("def ");
buffer.push_str(&function.name);
buffer.push('(');
buffer.push_str(&parameters.join(", "));
buffer.push(')');
if let Some(returns) = &function.returns {
buffer.push_str(" -> ");
buffer.push_str(annotation_stub(returns, modules_to_import));
}
buffer.push_str(": ...");
buffer
}

Expand All @@ -132,11 +137,7 @@ fn argument_stub(argument: &Argument, modules_to_import: &mut BTreeSet<String>)
let mut output = argument.name.clone();
if let Some(annotation) = &argument.annotation {
output.push_str(": ");
output.push_str(annotation);
if let Some((module, _)) = annotation.rsplit_once('.') {
// TODO: this is very naive
modules_to_import.insert(module.into());
}
output.push_str(annotation_stub(annotation, modules_to_import));
}
if let Some(default_value) = &argument.default_value {
output.push_str(if argument.annotation.is_some() {
Expand All @@ -153,6 +154,14 @@ fn variable_length_argument_stub(argument: &VariableLengthArgument) -> String {
argument.name.clone()
}

fn annotation_stub<'a>(annotation: &'a str, modules_to_import: &mut BTreeSet<String>) -> &'a str {
if let Some((module, _)) = annotation.rsplit_once('.') {
// TODO: this is very naive
modules_to_import.insert(module.into());
}
annotation
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -186,9 +195,10 @@ mod tests {
name: "kwarg".into(),
}),
},
returns: Some("list[str]".into()),
};
assert_eq!(
"def func(posonly, /, arg, *varargs, karg: str, **kwarg): ...",
"def func(posonly, /, arg, *varargs, karg: str, **kwarg) -> list[str]: ...",
function_stubs(&function, &mut BTreeSet::new())
)
}
Expand Down Expand Up @@ -217,6 +227,7 @@ mod tests {
}],
kwarg: None,
},
returns: None,
};
assert_eq!(
"def afunc(posonly=1, /, arg=True, *, karg: str = \"foo\"): ...",
Expand Down
29 changes: 28 additions & 1 deletion pyo3-macros-backend/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::mem::take;
use std::sync::atomic::{AtomicUsize, Ordering};
use syn::ext::IdentExt;
use syn::visit_mut::{visit_type_mut, VisitMut};
use syn::{Attribute, Ident, Type, TypePath};
use syn::{Attribute, Ident, ReturnType, Type, TypePath};

static GLOBAL_COUNTER_FOR_UNIQUE_NAMES: AtomicUsize = AtomicUsize::new(0);

Expand Down Expand Up @@ -99,12 +99,14 @@ pub fn class_introspection_code(
.emit(pyo3_crate_path)
}

#[allow(clippy::too_many_arguments)]
pub fn function_introspection_code(
pyo3_crate_path: &PyO3CratePath,
ident: Option<&Ident>,
name: &str,
signature: &FunctionSignature<'_>,
first_argument: Option<&'static str>,
returns: ReturnType,
decorators: impl IntoIterator<Item = String>,
parent: Option<&Type>,
) -> TokenStream {
Expand All @@ -115,6 +117,25 @@ pub fn function_introspection_code(
"arguments",
arguments_introspection_data(signature, first_argument, parent),
),
(
"returns",
match returns {
ReturnType::Default => IntrospectionNode::String("None".into()),
ReturnType::Type(_, ty) => match *ty {
Type::Tuple(t) if t.elems.is_empty() => {
// () is converted to None in return types
IntrospectionNode::String("None".into())
}
mut ty => {
if let Some(class_type) = parent {
replace_self(&mut ty, class_type);
}
ty = ty.elide_lifetimes();
IntrospectionNode::OutputType { rust_type: ty }
}
},
},
),
]);
if let Some(ident) = ident {
desc.insert(
Expand Down Expand Up @@ -290,6 +311,7 @@ enum IntrospectionNode<'a> {
String(Cow<'a, str>),
IntrospectionId(Option<Cow<'a, Type>>),
InputType { rust_type: Type, nullable: bool },
OutputType { rust_type: Type },
Map(HashMap<&'static str, IntrospectionNode<'a>>),
List(Vec<IntrospectionNode<'a>>),
}
Expand Down Expand Up @@ -340,6 +362,11 @@ impl IntrospectionNode<'_> {
}
content.push_str("\"");
}
Self::OutputType { rust_type } => {
content.push_str("\"");
content.push_tokens(quote! { <#rust_type as #pyo3_crate_path::impl_::introspection::PyReturnType>::OUTPUT_TYPE });
content.push_str("\"");
}
Self::Map(map) => {
content.push_str("{");
for (i, (key, value)) in map.into_iter().enumerate() {
Expand Down
4 changes: 4 additions & 0 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub warnings: Vec<PyFunctionWarning>,
#[cfg(feature = "experimental-inspect")]
pub output: syn::ReturnType,
}

pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
Expand Down Expand Up @@ -526,6 +528,8 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
warnings,
#[cfg(feature = "experimental-inspect")]
output: sig.output.clone(),
})
}

Expand Down
6 changes: 6 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,8 @@ fn complex_enum_struct_variant_new<'a>(
asyncness: None,
unsafety: None,
warnings: vec![],
#[cfg(feature = "experimental-inspect")]
output: syn::ReturnType::Default,
};

crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
Expand Down Expand Up @@ -1725,6 +1727,8 @@ fn complex_enum_tuple_variant_new<'a>(
asyncness: None,
unsafety: None,
warnings: vec![],
#[cfg(feature = "experimental-inspect")]
output: syn::ReturnType::Default,
};

crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
Expand All @@ -1750,6 +1754,8 @@ fn complex_enum_variant_field_getter<'a>(
asyncness: None,
unsafety: None,
warnings: vec![],
#[cfg(feature = "experimental-inspect")]
output: syn::ReturnType::Type(Token![->](field_span), Box::new(variant_cls_type.clone())),
};

let property_type = crate::pymethod::PropertyType::Function {
Expand Down
3 changes: 3 additions & 0 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ pub fn impl_wrap_pyfunction(
&name.to_string(),
&signature,
None,
func.sig.output.clone(),
[] as [String; 0],
None,
);
Expand All @@ -410,6 +411,8 @@ pub fn impl_wrap_pyfunction(
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
warnings,
#[cfg(feature = "experimental-inspect")]
output: func.sig.output.clone(),
};

let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
Expand Down
3 changes: 3 additions & 0 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ fn method_introspection_code(spec: &FnSpec<'_>, parent: &syn::Type, ctx: &Ctx) -

// We introduce self/cls argument and setup decorators
let mut first_argument = None;
let mut output = spec.output.clone();
let mut decorators = Vec::new();
match &spec.tp {
FnType::Getter(_) => {
Expand All @@ -382,6 +383,7 @@ fn method_introspection_code(spec: &FnSpec<'_>, parent: &syn::Type, ctx: &Ctx) -
}
FnType::FnNew | FnType::FnNewClass(_) => {
first_argument = Some("cls");
output = syn::ReturnType::Default; // The __new__ Python function return type is None
}
FnType::FnClass(_) => {
first_argument = Some("cls");
Expand All @@ -404,6 +406,7 @@ fn method_introspection_code(spec: &FnSpec<'_>, parent: &syn::Type, ctx: &Ctx) -
&name,
&spec.signature,
first_argument,
output,
decorators,
Some(parent),
)
Expand Down
28 changes: 14 additions & 14 deletions pytests/stubs/pyclasses.pyi
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import typing

class AssertingBaseClass:
def __new__(cls, /, expected_type: typing.Any): ...
def __new__(cls, /, expected_type: typing.Any) -> None: ...

class ClassWithDecorators:
def __new__(cls, /): ...
def __new__(cls, /) -> None: ...
@property
def attr(self, /): ...
def attr(self, /) -> int: ...
@attr.setter
def attr(self, /, value: int): ...
def attr(self, /, value: int) -> None: ...
@classmethod
@property
def cls_attribute(cls, /): ...
def cls_attribute(cls, /) -> int: ...
@classmethod
def cls_method(cls, /): ...
def cls_method(cls, /) -> int: ...
@staticmethod
def static_method(): ...
def static_method() -> int: ...

class ClassWithoutConstructor: ...

class EmptyClass:
def __len__(self, /): ...
def __new__(cls, /): ...
def method(self, /): ...
def __len__(self, /) -> int: ...
def __new__(cls, /) -> None: ...
def method(self, /) -> None: ...

class PyClassIter:
def __new__(cls, /): ...
def __next__(self, /): ...
def __new__(cls, /) -> None: ...
def __next__(self, /) -> int: ...

class PyClassThreadIter:
def __new__(cls, /): ...
def __next__(self, /): ...
def __new__(cls, /) -> None: ...
def __next__(self, /) -> int: ...
18 changes: 10 additions & 8 deletions pytests/stubs/pyfunctions.pyi
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import typing

def args_kwargs(*args, **kwargs): ...
def none(): ...
def positional_only(a: typing.Any, /, b: typing.Any): ...
def args_kwargs(*args, **kwargs) -> typing.Any: ...
def none() -> None: ...
def positional_only(a: typing.Any, /, b: typing.Any) -> typing.Any: ...
def simple(
a: typing.Any, b: typing.Any | None = None, *, c: typing.Any | None = None
): ...
) -> typing.Any: ...
def simple_args(
a: typing.Any, b: typing.Any | None = None, *args, c: typing.Any | None = None
): ...
) -> typing.Any: ...
def simple_args_kwargs(
a: typing.Any,
b: typing.Any | None = None,
*args,
c: typing.Any | None = None,
**kwargs,
): ...
) -> typing.Any: ...
def simple_kwargs(
a: typing.Any, b: typing.Any | None = None, c: typing.Any | None = None, **kwargs
): ...
def with_typed_args(a: bool = False, b: int = 0, c: float = 0.0, d: str = ""): ...
) -> typing.Any: ...
def with_typed_args(
a: bool = False, b: int = 0, c: float = 0.0, d: str = ""
) -> typing.Any: ...
10 changes: 10 additions & 0 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ pub trait IntoPyObject<'py>: Sized {
/// The type returned in the event of a conversion error.
type Error: Into<PyErr>;

/// Extracts the type hint information for this type when it appears as a return value.
///
/// For example, `Vec<u32>` would return `List[int]`.
/// The default implementation returns `Any`, which is correct for any type.
///
/// For most types, the return value for this method will be identical to that of [`FromPyObject::INPUT_TYPE`].
/// It may be different for some types, such as `Dict`, to allow duck-typing: functions return `Dict` but take `Mapping` as argument.
#[cfg(feature = "experimental-inspect")]
const OUTPUT_TYPE: &'static str = "typing.Any";

/// Performs the conversion.
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error>;

Expand Down
Loading
Loading