From 5b53eade8717fdeae488d554543b2a72de0a615b Mon Sep 17 00:00:00 2001 From: Roland Fredenhagen Date: Sat, 23 Mar 2024 17:32:45 +0100 Subject: [PATCH] tuple struct support --- CHANGELOG.md | 1 + macro/src/lib.rs | 322 +++++++++++++++++++++++++++++++---------------- src/lib.rs | 36 ++++-- tests/derive.rs | 37 ++++++ 4 files changed, 280 insertions(+), 116 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5cf399..8872d43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - `from_attribute` and `from_attribue_partial` to `FromAttr`. +- Support for tuple structs. ### Fixed - `FromAttr` did not support `#[flag]` or `#[name = value]` attribute styles at the root. diff --git a/macro/src/lib.rs b/macro/src/lib.rs index 2b1901d..47ebdf1 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -1,16 +1,17 @@ use std::borrow::Cow; use std::collections::{HashMap, HashSet}; +use std::fmt::Display; use std::iter; use collection_literals::hash; use interpolator::{format, Formattable}; -use manyhow::{bail, error_message, manyhow, ErrorMessage, Result}; -use proc_macro2::{Span, TokenStream}; +use manyhow::{bail, error_message, manyhow, span_range, ErrorMessage, Result}; +use proc_macro2::{Literal, Span, TokenStream}; use proc_macro_utils::{TokenParser, TokenStream2Ext}; use quote::{format_ident, ToTokens}; use quote_use::quote_use as quote; use syn::spanned::Spanned; -use syn::{DataStruct, DeriveInput, Field, Fields, FieldsNamed, Generics, Ident, LitStr, Type}; +use syn::{DataStruct, DeriveInput, Field, Fields, Generics, Ident, LitStr, Type}; const ATTRIBUTE_IDENT: &str = "attribute"; @@ -439,8 +440,31 @@ impl Conflicts { } } +enum IdentOrIdx { + Ident(Ident), + Idx(usize), +} + +impl Display for IdentOrIdx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IdentOrIdx::Ident(i) => write!(f, "{i}"), + IdentOrIdx::Idx(i) => write!(f, "{i}"), + } + } +} + +impl ToTokens for IdentOrIdx { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + IdentOrIdx::Ident(i) => i.to_tokens(tokens), + IdentOrIdx::Idx(i) => Literal::usize_unsuffixed(*i).to_tokens(tokens), + } + } +} + struct AttrField { - ident: Ident, + ident: IdentOrIdx, duplicate: LitStr, help: Option, missing: String, @@ -451,84 +475,106 @@ struct AttrField { impl AttrField { fn parse_fields( - input: FieldsNamed, + input: impl IntoIterator, struct_error: &StructError, attribute_ident: Option, ) -> Result<(Vec, Conflicts)> { let mut conflicts = Conflicts::default(); Ok(( - input - .named - .into_iter() - .map( - |Field { - attrs, ident, ty, .. - }| { - let ident = ident.expect("named struct fields should have idents"); - - let FieldAttrs { - optional, - default, - conflicts: field_conflicts, - example, - positional, - } = FieldAttrs::from_attrs(attrs)?; - - for conflict in field_conflicts { - conflicts.insert(ident.clone(), conflict); - } + input + .into_iter() + .enumerate() + .map( + |( + idx, + Field { + attrs, ident, ty, .. + }, + )| { + let field_name = ident + .as_ref() + .map_or_else(|| format!("positional_{idx}"), ToString::to_string); + + let FieldAttrs { + optional, + default, + conflicts: field_conflicts, + example, + mut positional, + } = FieldAttrs::from_attrs(attrs)?; + + positional |= ident.is_none(); + + for conflict in field_conflicts { + if positional { + bail!( + ident.map_or_else( + || span_range!(ty), + |ident| ident.span().span_range() + ), + "positional fields do not support conflicts" + ) + } + conflicts.insert( + ident + .as_ref() + .expect("unnamed fields should be positional") + .clone(), + conflict, + ); + } - let duplicate = - struct_error.duplicate_field().format( - hash!("field" => Formattable::display(&ident)) - )?; + let duplicate = struct_error + .duplicate_field() + .format(hash!("field" => Formattable::display(&field_name)))?; - let help = if let Some(help) = struct_error.field_help() { - if attribute_ident.is_none() && help.default { + let help = if let Some(help) = struct_error.field_help() { + if attribute_ident.is_none() && help.default { + None + } else { + let example = &example.as_deref().unwrap_or("..."); + let mut context = hash!( + "field" => Formattable::display(&field_name), + "example" => Formattable::display(example), + "open_or_eq" => Formattable::display(&"{open_or_eq}"), + "close_or_empty" => Formattable::display(&"{close_or_empty}")); + if let Some(ident) = &attribute_ident { + context.insert("attribute", Formattable::display(ident)); + } + Some(format!("\n\n= help: {}", help.format(context)?.value())) + } + } else { None + }; + + let missing = struct_error + .missing_field() + .format(hash!("field" => Formattable::display(&field_name)))? + .value() + + help.as_deref().unwrap_or_default(); + + let default = if let Some(default) = default { + Some(default.to_token_stream()) + } else if optional { + Some(quote!(<#ty as Default>::default())) } else { - let example = &example.as_deref().unwrap_or("..."); - let mut context = hash!("field" => Formattable::display(&ident), "example" => Formattable::display(example), "open_or_eq" => Formattable::display(&"{open_or_eq}"), "close_or_empty" => Formattable::display(&"{close_or_empty}")); - if let Some(ident) = &attribute_ident { - context.insert("attribute", Formattable::display(ident)); - } - Some(format!( - "\n\n= help: {}", - help.format(context)?.value() - )) - } - } else { - None - }; - - let missing = - struct_error.missing_field().format( - hash!("field" => Formattable::display(&ident)), - )?.value() - + help.as_deref().unwrap_or_default(); - - let default = if let Some(default) = default { - Some(default.to_token_stream()) - } else if optional { - Some(quote!(<#ty as Default>::default())) - } else { - None - }; - - Ok(AttrField { - positional, - ident, - duplicate, - help, - missing, - default, - ty - }) - }, - ) - .collect::>()?, - conflicts, - )) + None + }; + + Ok(AttrField { + positional, + ident: ident.map_or(IdentOrIdx::Idx(idx), IdentOrIdx::Ident), + duplicate, + help, + missing, + default, + ty, + }) + }, + ) + .collect::>()?, + conflicts, + )) } fn map_error(&self, ty: &Type) -> TokenStream { @@ -553,11 +599,18 @@ impl AttrField { fn partial(&self) -> TokenStream { let ty = &self.ty; - let ident = &self.ident; - if self.positional { - quote!(#ident: Option<::attribute_derive::parsing::SpannedValue<<#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>) + if let IdentOrIdx::Ident(ref ident) = self.ident { + if self.positional { + quote!(#ident: Option<::attribute_derive::parsing::SpannedValue< + <#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>) + } else { + quote!(#ident: Option<::attribute_derive::parsing::Named< + ::attribute_derive::parsing::SpannedValue< + <#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>>) + } } else { - quote!(#ident: Option<::attribute_derive::parsing::Named<::attribute_derive::parsing::SpannedValue<<#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>>) + quote!(Option<::attribute_derive::parsing::SpannedValue< + <#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>) } } @@ -597,13 +650,19 @@ impl AttrField { # use attribute_derive::parsing::{AttributeBase, AttributeNamed}; # use attribute_derive::from_partial::FromPartial; if let Some(__field) = - <#ty as AttributeNamed>::parse_named(#s_ident, __input) #map_error? { + <#ty as AttributeNamed>::parse_named(#s_ident, __input) #map_error? + { let __name = __field.name; - __partial.#ident = <#ty as FromPartial<<#ty as AttributeBase>::Partial>> ::join( - std::mem::take(&mut __partial.#ident).map(|__v|__v.value), - __field.value, - #duplicate - )?.map(|__v| ::attribute_derive::parsing::Named{name: __name, value: __v}); + __partial.#ident = + <#ty as FromPartial<<#ty as AttributeBase>::Partial>>::join( + std::mem::take(&mut __partial.#ident).map(|__v| __v.value), + __field.value, + #duplicate, + )? + .map(|__v| ::attribute_derive::parsing::Named { + name: __name, + value: __v, + }); #parse_comma; continue; @@ -627,9 +686,11 @@ impl AttrField { let fmt = format!("{missing}{{open_or_eq:.0}}{{close_or_empty:.0}}"); let attr_named = quote!(<#ty as ::attribute_derive::parsing::AttributeNamed>); quote! { - #ident: <#ty as ::attribute_derive::from_partial::FromPartial<<#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>::from_option( + #ident: <#ty as ::attribute_derive::from_partial::FromPartial< + <#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>::from_option( __partial.#ident.map(|__v| __v.value()), - &format!(#fmt, open_or_eq=#attr_named::PREFERRED_OPEN_DELIMITER, close_or_empty=#attr_named::PREFERRED_CLOSE_DELIMITER) + &format!(#fmt, open_or_eq=#attr_named::PREFERRED_OPEN_DELIMITER, + close_or_empty=#attr_named::PREFERRED_CLOSE_DELIMITER) )#unwrap } } @@ -642,7 +703,9 @@ impl AttrField { .. } = self; let join = quote! { - __first.#ident = <#ty as ::attribute_derive::from_partial::FromPartial<<#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>::join(__first_value, __second_value, #duplicate)? + __first.#ident = <#ty as ::attribute_derive::from_partial::FromPartial< + <#ty as ::attribute_derive::parsing::AttributeBase>::Partial>>::join( + __first_value, __second_value, #duplicate)? }; if self.positional { @@ -655,7 +718,8 @@ impl AttrField { } else { quote! { if let Some(__second) = __second.#ident { - let (__first_name, __first_value) = if let Some(__first) = ::std::mem::take(&mut __first.#ident) { + let (__first_name, __first_value) = if let Some(__first) = + ::std::mem::take(&mut __first.#ident) { (Some(__first.name), Some(__first.value)) } else {(None, None)}; let __name = __first_name.unwrap_or(__second.name); @@ -673,18 +737,25 @@ fn parse_comma() -> TokenStream { if __input.is_empty() { return Ok(__partial); } else { - <::attribute_derive::__private::syn::Token![,] as ::attribute_derive::__private::syn::parse::Parse>::parse(__input)?; + <::attribute_derive::__private::syn::Token![,] as + ::attribute_derive::__private::syn::parse::Parse>::parse(__input)?; } } } fn partial_attribute(partial: &Ident, fields: &[AttrField], generics: &Generics) -> Result { + let Some(first_field) = fields.first() else { + return Ok(quote!(#[derive(Default)] struct #partial #generics {})); + }; let fields = fields.iter().map(AttrField::partial); + let fields = if matches!(first_field.ident, IdentOrIdx::Ident(_)) { + quote!({#(#fields),*}) + } else { + quote!((#(#fields),*);) + }; Ok(quote! { #[derive(Default)] - struct #partial #generics { - #(#fields),* - } + struct #partial #generics #fields }) } @@ -735,7 +806,7 @@ pub fn from_attr_derive( let (fields, conflicts) = match data { syn::Data::Struct(DataStruct { - fields: Fields::Named(fields), + fields: fields @ (Fields::Named(_) | Fields::Unnamed(_)), .. }) => AttrField::parse_fields(fields, struct_error, attribute_ident)?, _ => bail!("only works on structs with named fields"), @@ -747,13 +818,52 @@ pub fn from_attr_derive( let error_invalid_name = struct_error.unknown_field_error(&fields)?; - let parse_positionals = fields.iter().filter_map(AttrField::parse_positional); - let parse_named = fields.iter().filter_map(AttrField::parse_named); let partial_fields = fields.iter().map(AttrField::assign_partial); let join_fields = fields.iter().map(AttrField::join_field); + let from_attr = if fields.len() == 1 && fields[0].positional { + // newtype struct + let AttrField { ref ty, .. } = fields[0]; + quote! { + # use ::attribute_derive::FromAttr; + # use ::attribute_derive::parsing::SpannedValue; + # use ::attribute_derive::__private::syn::parse::{ParseStream, Parse}; + # use ::attribute_derive::__private::syn::{Result, Error}; + impl #impl_generics FromAttr for #ident #ty_generics #where_clause { + fn parse_partial(input: ParseStream) -> Result { + <#ty as FromAttr>::parse_partial(input) + .map(SpannedValue::call_site) + .map(Some) + .map(#partial_ident) + } + } + + } + } else { + let parse_positionals: Vec<_> = fields + .iter() + .filter_map(AttrField::parse_positional) + .collect(); + let parse_named: Vec<_> = fields.iter().filter_map(AttrField::parse_named).collect(); + quote! { + # use ::attribute_derive::parsing::AttributeMeta; + # use ::attribute_derive::__private::syn::parse::{ParseStream, Parse}; + # use ::attribute_derive::__private::syn::{Result, Error}; + impl #impl_generics AttributeMeta for #ident #ty_generics #where_clause { + fn parse_inner(__input: ParseStream) -> Result { + let mut __partial = #partial_ident::default(); + #(#parse_positionals)* + while !__input.is_empty() { + #(#parse_named)* + return Err(__input.error(#error_invalid_name)); + } + Ok(__partial) + } + } + } + }; Ok(quote! { - # use ::attribute_derive::parsing::{AttributeBase, AttributeMeta, SpannedValue}; + # use ::attribute_derive::parsing::{AttributeBase, SpannedValue}; # use ::attribute_derive::from_partial::FromPartial; # use ::attribute_derive::__private::syn::parse::{ParseStream, Parse}; # use ::attribute_derive::__private::syn::{Result, Error}; @@ -766,24 +876,16 @@ pub fn from_attr_derive( type Partial = #partial_ident #ty_generics; } - impl #impl_generics AttributeMeta for #ident #ty_generics #where_clause { - fn parse_inner(__input: ParseStream) -> Result { - let mut __partial = #partial_ident::default(); - #(#parse_positionals)* - while !__input.is_empty() { - #(#parse_named)* - return Err(__input.error(#error_invalid_name)); - } - Ok(__partial) - } - } + #from_attr + impl #impl_generics Parse for #ident #ty_generics #where_clause { fn parse(__input: ParseStream) -> Result { ::parse_input(__input) } } - impl #impl_generics FromPartial<#partial_ident #ty_generics> for #ident #ty_generics #where_clause { + impl #impl_generics FromPartial<#partial_ident #ty_generics> for #ident #ty_generics + #where_clause { fn from(__partial: #partial_ident #ty_generics) -> Result { #conflicts Ok(Self { @@ -791,8 +893,10 @@ pub fn from_attr_derive( }) } - fn from_option(__partial: Option<#partial_ident #ty_generics>, _: &str) -> Result { - >::from(__partial.unwrap_or_default()) + fn from_option(__partial: Option<#partial_ident #ty_generics>, _: &str) + -> Result { + >::from( + __partial.unwrap_or_default()) } fn join( diff --git a/src/lib.rs b/src/lib.rs index 01cde5f..7e0118f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,20 +34,42 @@ //! //! Any type that for [`AttributeNamed`] or [`AttributePositional`] are //! implemented respectively are supported. These should be the general types -//! that syn supports like [`LitStr`](struct@LitStr) or [`Type`] or that have a -//! direct equivalent in those like [`String`], [`char`] or [`f32`]. A special -//! treatment have [`Vecs`](Vec) which are parsed as either `name = [a, b, c]` -//! or `name(a, b, c)` and [`Options`](Option) that will be [`None`] if not -//! specified and [`Some`] when the value is specified via the attribute. It is -//! not specified via `Some(value)` but as just `value`. [`Bools`](bool) are +//! that [`syn`] supports like [`LitStr`](struct@LitStr) or [`Type`] or that +//! have a direct equivalent in those like [`String`], [`char`] or [`f32`]. A +//! special treatment have [`Vecs`](Vec) which are parsed as either `name = [a, +//! b, c]` or `name(a, b, c)` and [`Options`](Option) that will be [`None`] if +//! not specified and [`Some`] when the value is specified via the attribute. It +//! is not specified via `Some(value)` but as just `value`. [`Bools`](bool) are //! used for flags, i.e., without a value. Most should just behave as expected, //! see [`parsing`] for details. //! +//! Tuple structs can derive [`FromAttr`] as well, but all fields will be +//! positional. Tuples with a single field +//! ([new types](https://rust-unofficial.github.io/patterns/patterns/behavioural/newtype.html)) +//! will copy the behavior of the contained field, e.g. for [`bool`]: +//! +//! ``` +//! use attribute_derive::FromAttr; +//! +//! #[derive(FromAttr, PartialEq, Debug)] +//! #[attribute(ident = flag)] +//! struct Flag(bool); +//! +//! let attr: Attribute = parse_quote!(#[flag]); +//! assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(true)); +//! +//! let attr: Attribute = parse_quote!(#[flag = true]); +//! assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(true)); +//! +//! let attr: Attribute = parse_quote!(#[flag(false)]); +//! assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(false)); +//! ``` +//! //! # Attributes //! //! The parsing of attributes can be modified with the following parameters via //! the `#[attribute()]` attribute. All of them are optional. Error -//! messages are formatted using [interpolator], and only support display and on +//! messages are formatted using [interpolator], and only support display and //! lists `i` formatting. See [interpolator] docs for details. //! //! ### Struct diff --git a/tests/derive.rs b/tests/derive.rs index 4866107..b3aae9c 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -219,3 +219,40 @@ fn literal_attrs() { .is_flag() ); } + +#[test] +fn tuple() { + #[derive(FromAttr, PartialEq, Debug)] + #[attribute(ident = flag)] + struct Flag(bool); + + let attr: Attribute = parse_quote!(#[flag]); + assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(true)); + let attr: Attribute = parse_quote!(#[flag = true]); + assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(true)); + let attr: Attribute = parse_quote!(#[flag(false)]); + assert_eq!(Flag::from_attribute(attr).unwrap(), Flag(false)); + + #[derive(FromAttr, PartialEq, Debug)] + #[attribute(ident = name_value)] + struct NameValue(String); + let attr: Attribute = parse_quote!(#[name_value = "value"]); + assert_eq!( + NameValue::from_attribute(attr).unwrap(), + NameValue("value".into()) + ); + let attr: Attribute = parse_quote!(#[name_value("value")]); + assert_eq!( + NameValue::from_attribute(attr).unwrap(), + NameValue("value".into()) + ); + + #[derive(FromAttr, PartialEq, Debug)] + #[attribute(ident = multiple)] + struct Multiple(bool, String); + let attr: Attribute = parse_quote!(#[multiple(true, "value")]); + assert_eq!( + Multiple::from_attribute(attr).unwrap(), + Multiple(true, "value".into()) + ); +}