diff --git a/num_enum_derive/src/enum_attributes.rs b/num_enum_derive/src/enum_attributes.rs index fd19a31..f005785 100644 --- a/num_enum_derive/src/enum_attributes.rs +++ b/num_enum_derive/src/enum_attributes.rs @@ -9,24 +9,29 @@ mod kw { syn::custom_keyword!(constructor); syn::custom_keyword!(error_type); syn::custom_keyword!(name); + syn::custom_keyword!(from_primitive); + syn::custom_keyword!(no_panic); } // Example: error_type(name = Foo, constructor = Foo::new) #[cfg_attr(test, derive(Debug))] pub(crate) struct Attributes { pub(crate) error_type: Option, + pub(crate) from_primitive: Option, } // Example: error_type(name = Foo, constructor = Foo::new) #[cfg_attr(test, derive(Debug))] pub(crate) enum AttributeItem { ErrorType(ErrorTypeAttribute), + FromPrimitive(FromPrimitiveAttribute), } impl Parse for Attributes { fn parse(input: ParseStream<'_>) -> Result { let attribute_items = input.parse_terminated(AttributeItem::parse, syn::Token![,])?; let mut maybe_error_type = None; + let mut maybe_from_primitive = None; for attribute_item in &attribute_items { match attribute_item { AttributeItem::ErrorType(error_type) => { @@ -38,10 +43,20 @@ impl Parse for Attributes { } maybe_error_type = Some(error_type.clone()); } + AttributeItem::FromPrimitive(from_primitive) => { + if maybe_from_primitive.is_some() { + return Err(Error::new( + from_primitive.span, + "num_enum attribute must have at most one from_primitive", + )); + } + maybe_from_primitive = Some(from_primitive.clone()); + } } } Ok(Self { error_type: maybe_error_type, + from_primitive: maybe_from_primitive, }) } } @@ -51,6 +66,8 @@ impl Parse for AttributeItem { let lookahead = input.lookahead1(); if lookahead.peek(kw::error_type) { input.parse().map(Self::ErrorType) + } else if lookahead.peek(kw::from_primitive) { + input.parse().map(Self::FromPrimitive) } else { Err(lookahead.error()) } @@ -168,6 +185,74 @@ impl Parse for ErrorTypeConstructorAttribute { } } +#[derive(Clone)] +#[cfg_attr(test, derive(Debug))] +pub(crate) struct FromPrimitiveAttribute { + pub(crate) no_panic: Option, + + span: Span, +} + +impl Parse for FromPrimitiveAttribute { + fn parse(input: ParseStream) -> Result { + let keyword: kw::from_primitive = input.parse()?; + let span = keyword.span; + let content; + syn::parenthesized!(content in input); + let attribute_values = + content.parse_terminated(FromPrimitiveNamedArgument::parse, syn::Token![,])?; + let mut no_panic = None; + for attribute_value in &attribute_values { + match attribute_value { + FromPrimitiveNamedArgument::NoPanic(no_panic_attr) => { + if no_panic.is_some() { + die!("num_enum from_primitive attribute must have exactly one `no_panic` value"); + } + no_panic = Some(no_panic_attr.clone()); + } + } + } + + Ok(Self { no_panic, span }) + } +} + +pub(crate) enum FromPrimitiveNamedArgument { + NoPanic(FromPrimitiveNoPanicAttribute), +} + +impl Parse for FromPrimitiveNamedArgument { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::no_panic) { + input.parse().map(Self::NoPanic) + } else { + Err(lookahead.error()) + } + } +} + +#[derive(Clone)] +#[cfg_attr(test, derive(Debug))] +pub(crate) struct FromPrimitiveNoPanicAttribute { + pub(crate) no_panic: Option, +} + +impl Parse for FromPrimitiveNoPanicAttribute { + fn parse(input: ParseStream) -> Result { + input.parse::()?; + + let no_panic: Option = if input.peek(syn::Token![=]) { + input.parse::()?; + Some(input.parse::()?.value) + } else { + None + }; + + Ok(Self { no_panic }) + } +} + #[cfg(test)] mod test { use crate::enum_attributes::Attributes; diff --git a/num_enum_derive/src/lib.rs b/num_enum_derive/src/lib.rs index 4fc4927..9a73f60 100644 --- a/num_enum_derive/src/lib.rs +++ b/num_enum_derive/src/lib.rs @@ -121,12 +121,19 @@ pub fn derive_from_primitive(input: TokenStream) -> TokenStream { let expression_idents: Vec> = enum_info.expression_idents(); let variant_expressions: Vec> = enum_info.variant_expressions(); + let no_panic_attr = if enum_info.no_panic { + quote!(#[no_panic::no_panic]) + } else { + quote!() + }; + debug_assert_eq!(variant_idents.len(), variant_expressions.len()); TokenStream::from(quote! { impl ::#krate::FromPrimitive for #name { type Primitive = #repr; + #no_panic_attr fn from_primitive(number: Self::Primitive) -> Self { // Use intermediate const(s) so that enums defined like // `Two = ONE + 1u8` work properly. diff --git a/num_enum_derive/src/parsing.rs b/num_enum_derive/src/parsing.rs index 1a24b0c..bcd4802 100644 --- a/num_enum_derive/src/parsing.rs +++ b/num_enum_derive/src/parsing.rs @@ -1,4 +1,4 @@ -use crate::enum_attributes::ErrorTypeAttribute; +use crate::enum_attributes::{ErrorTypeAttribute, FromPrimitiveAttribute}; use crate::utils::die; use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes}; use proc_macro2::Span; @@ -15,6 +15,7 @@ pub(crate) struct EnumInfo { pub(crate) repr: Ident, pub(crate) variants: Vec, pub(crate) error_type_info: ErrorType, + pub(crate) no_panic: bool, } impl EnumInfo { @@ -89,9 +90,10 @@ impl EnumInfo { fn parse_attrs>( attrs: Attrs, - ) -> Result<(Ident, Option)> { + ) -> Result<(Ident, Option, Option)> { let mut maybe_repr = None; let mut maybe_error_type = None; + let mut maybe_from_primitive = None; for attr in attrs { if let Meta::List(meta_list) = &attr.meta { if let Some(ident) = meta_list.path.get_ident() { @@ -122,6 +124,13 @@ impl EnumInfo { } maybe_error_type = Some(error_type.into()); } + + if let Some(from_primitive) = attributes.from_primitive { + if maybe_from_primitive.is_some() { + die!(attr => "At most one num_enum from_primitive attribute may be specified"); + } + maybe_from_primitive = Some(from_primitive.into()); + } } } } @@ -129,7 +138,7 @@ impl EnumInfo { if maybe_repr.is_none() { die!("Missing `#[repr({Integer})]` attribute"); } - Ok((maybe_repr.unwrap(), maybe_error_type)) + Ok((maybe_repr.unwrap(), maybe_error_type, maybe_from_primitive)) } } @@ -144,7 +153,8 @@ impl Parse for EnumInfo { Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"), }; - let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?; + let (repr, maybe_error_type, maybe_from_primitive) = + Self::parse_attrs(input.attrs.into_iter())?; let mut variants: Vec = vec![]; let mut has_default_variant: bool = false; @@ -396,12 +406,21 @@ impl Parse for EnumInfo { }, } }); + + let no_panic = match maybe_from_primitive { + None => false, + Some(from_primitive) => match from_primitive.no_panic { + None => false, + Some(no_panic_attribute) => no_panic_attribute.no_panic.unwrap_or_else(|| true), + }, + }; EnumInfo { name, repr, variants, error_type_info, + no_panic, } }) }