Skip to content

Commit 432e2c4

Browse files
committed
Address latest comments
1 parent 9dda661 commit 432e2c4

File tree

5 files changed

+70
-66
lines changed

5 files changed

+70
-66
lines changed

guide/src/class.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,16 +1209,16 @@ Python::with_gil(|py| {
12091209
12101210
assert isinstance(square, cls)
12111211
assert isinstance(square, cls.RegularPolygon)
1212-
assert square._0 == 4
1213-
assert square._1 == 10.0
1212+
assert square[0] == 4 # Gets _0 field
1213+
assert square[1] == 10.0 # Gets _1 field
12141214
12151215
def count_vertices(cls, shape):
12161216
match shape:
12171217
case cls.Circle():
12181218
return 0
12191219
case cls.Rectangle():
12201220
return 4
1221-
case cls.RegularPolygon(_0=n):
1221+
case cls.RegularPolygon(n):
12221222
return n
12231223
case cls.Nothing():
12241224
return 0

pyo3-macros-backend/src/pyclass.rs

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -975,19 +975,19 @@ fn impl_complex_enum(
975975
Ok(quote! {
976976
#pytypeinfo
977977

978-
#pyclass_impls
978+
#pyclass_impls
979979

980-
#[doc(hidden)]
981-
#[allow(non_snake_case)]
982-
impl #cls {}
980+
#[doc(hidden)]
981+
#[allow(non_snake_case)]
982+
impl #cls {}
983983

984-
#(#variant_cls_zsts)*
984+
#(#variant_cls_zsts)*
985985

986-
#(#variant_cls_pytypeinfos)*
986+
#(#variant_cls_pytypeinfos)*
987987

988-
#(#variant_cls_pyclass_impls)*
988+
#(#variant_cls_pyclass_impls)*
989989

990-
#(#variant_cls_impls)*
990+
#(#variant_cls_impls)*
991991
})
992992
}
993993

@@ -1006,6 +1006,36 @@ fn impl_complex_enum_variant_cls(
10061006
}
10071007
}
10081008

1009+
fn impl_complex_enum_variant_match_args(
1010+
ctx: &Ctx,
1011+
variant_cls_type: &syn::Type,
1012+
field_names: &mut Vec<Ident>,
1013+
) -> (MethodAndMethodDef, syn::ImplItemConst) {
1014+
let match_args_const_impl: syn::ImplItemConst = {
1015+
let args_tp = field_names.iter().map(|_| {
1016+
quote! { &'static str }
1017+
});
1018+
parse_quote! {
1019+
const __match_args__: ( #(#args_tp,)* ) = (
1020+
#(stringify!(#field_names),)*
1021+
);
1022+
}
1023+
};
1024+
1025+
let spec = ConstSpec {
1026+
rust_ident: format_ident!("__match_args__"),
1027+
attributes: ConstAttributes {
1028+
is_class_attr: true,
1029+
name: None,
1030+
deprecations: Deprecations::new(ctx),
1031+
},
1032+
};
1033+
1034+
let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);
1035+
1036+
(variant_match_args, match_args_const_impl)
1037+
}
1038+
10091039
fn impl_complex_enum_struct_variant_cls(
10101040
enum_name: &syn::Ident,
10111041
variant: &PyClassEnumStructVariant<'_>,
@@ -1043,6 +1073,11 @@ fn impl_complex_enum_struct_variant_cls(
10431073
field_getter_impls.push(field_getter_impl);
10441074
}
10451075

1076+
let (variant_match_args, match_args_const_impl) =
1077+
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);
1078+
1079+
field_getters.push(variant_match_args);
1080+
10461081
let cls_impl = quote! {
10471082
#[doc(hidden)]
10481083
#[allow(non_snake_case)]
@@ -1052,6 +1087,8 @@ fn impl_complex_enum_struct_variant_cls(
10521087
#pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
10531088
}
10541089

1090+
#match_args_const_impl
1091+
10551092
#(#field_getter_impls)*
10561093
}
10571094
};
@@ -1171,52 +1208,6 @@ fn impl_complex_enum_tuple_variant_getitem(
11711208
Ok((variant_getitem, get_item_method_impl))
11721209
}
11731210

1174-
fn impl_complex_enum_tuple_variant_match_args(
1175-
ctx: &Ctx,
1176-
variant_cls_type: &syn::Type,
1177-
field_names: &mut Vec<Ident>,
1178-
) -> (MethodAndMethodDef, syn::ImplItemConst) {
1179-
let match_args_const_impl: syn::ImplItemConst = match field_names.len() {
1180-
// This covers the case where the tuple variant has no fields (valid Rust)
1181-
0 => parse_quote! {
1182-
const __match_args__: () = ();
1183-
},
1184-
1 => {
1185-
let ident = &field_names[0];
1186-
// We need the trailing comma to make it a tuple
1187-
parse_quote! {
1188-
const __match_args__: (&'static str ,) = (stringify!(#ident) , );
1189-
}
1190-
}
1191-
_ => {
1192-
let args_tp = field_names.iter().map(|_| {
1193-
quote! { &'static str }
1194-
});
1195-
parse_quote! {
1196-
const __match_args__: ( #(#args_tp),* ) = (
1197-
#(stringify!(#field_names),)*
1198-
);
1199-
}
1200-
}
1201-
};
1202-
1203-
let spec = ConstSpec {
1204-
rust_ident: format_ident!("__match_args__"),
1205-
attributes: ConstAttributes {
1206-
is_class_attr: true,
1207-
name: Some(NameAttribute {
1208-
kw: syn::parse_quote! { name },
1209-
value: NameLitStr(format_ident!("__match_args__")),
1210-
}),
1211-
deprecations: Deprecations::new(ctx),
1212-
},
1213-
};
1214-
1215-
let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);
1216-
1217-
(variant_match_args, match_args_const_impl)
1218-
}
1219-
12201211
fn impl_complex_enum_tuple_variant_cls(
12211212
enum_name: &syn::Ident,
12221213
variant: &PyClassEnumTupleVariant<'_>,
@@ -1256,7 +1247,7 @@ fn impl_complex_enum_tuple_variant_cls(
12561247
slots.push(variant_getitem);
12571248

12581249
let (variant_match_args, match_args_method_impl) =
1259-
impl_complex_enum_tuple_variant_match_args(ctx, &variant_cls_type, &mut field_names);
1250+
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);
12601251

12611252
field_getters.push(variant_match_args);
12621253

@@ -1477,10 +1468,6 @@ fn complex_enum_tuple_variant_new<'a>(
14771468
let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>);
14781469

14791470
let args = {
1480-
let mut no_pyo3_attrs = vec![];
1481-
let _attrs =
1482-
crate::pyfunction::PyFunctionArgPyO3Attributes::from_attrs(&mut no_pyo3_attrs)?;
1483-
14841471
let mut args = vec![FnArg::Py(PyArg {
14851472
name: &arg_py_ident,
14861473
ty: &arg_py_type,
@@ -1497,7 +1484,16 @@ fn complex_enum_tuple_variant_new<'a>(
14971484
}
14981485
args
14991486
};
1500-
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
1487+
1488+
let signature = if let Some(constructor) = variant.options.constructor {
1489+
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
1490+
args,
1491+
constructor.into_signature(),
1492+
)?
1493+
} else {
1494+
crate::pyfunction::FunctionSignature::from_arguments(args)?
1495+
};
1496+
15011497
let spec = FnSpec {
15021498
tp: crate::method::FnType::FnNew,
15031499
name: &format_ident!("__pymethod_constructor__"),

pytests/src/enums.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ enum SimpleTupleEnum {
9393

9494
#[pyclass]
9595
pub enum TupleEnum {
96+
#[pyo3(constructor = (_0 = 1, _1 = 1.0, _2 = true))]
97+
FullWithDefault(i32, f64, bool),
9698
Full(i32, f64, bool),
9799
EmptyTuple(),
98100
}
99101

100102
#[pyfunction]
101103
pub fn do_tuple_stuff(thing: &TupleEnum) -> TupleEnum {
102104
match thing {
105+
TupleEnum::FullWithDefault(a, b, c) => TupleEnum::FullWithDefault(*a, *b, *c),
103106
TupleEnum::Full(a, b, c) => TupleEnum::Full(*a, *b, *c),
104107
TupleEnum::EmptyTuple() => TupleEnum::EmptyTuple(),
105108
}

pytests/tests/test_enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def test_tuple_enum_variant_constructors():
150150
@pytest.mark.parametrize(
151151
"variant",
152152
[
153+
enums.TupleEnum.FullWithDefault(),
153154
enums.TupleEnum.Full(42, 3.14, False),
154155
enums.TupleEnum.EmptyTuple(),
155156
],
@@ -158,6 +159,13 @@ def test_tuple_enum_variant_subclasses(variant: enums.TupleEnum):
158159
assert isinstance(variant, enums.TupleEnum)
159160

160161

162+
def test_tuple_enum_defaults():
163+
variant = enums.TupleEnum.FullWithDefault()
164+
assert variant._0 == 1
165+
assert variant._1 == 1.0
166+
assert variant._2 is True
167+
168+
161169
def test_tuple_enum_field_getters():
162170
tuple_variant = enums.TupleEnum.Full(42, 3.14, False)
163171
assert tuple_variant._0 == 42

pytests/tests/test_enums_match.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum):
6565
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
6666
],
6767
)
68-
@pytest.mark.skip(
69-
reason="__match_args__ is not supported for struct enums yet. TODO : Open an issue"
70-
)
7168
def test_complex_enum_partial_match(variant: enums.ComplexEnum):
7269
match variant:
7370
case enums.ComplexEnum.MultiFieldStruct(a):

0 commit comments

Comments
 (0)