Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion builder-pattern-macro/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use bitflags::bitflags;
use syn::{Attribute, Expr, Meta, NestedMeta};
use proc_macro2::TokenTree;
use syn::{Attribute, Expr, Ident, Meta, NestedMeta};

bitflags! {
pub struct Setters: u32 {
Expand All @@ -23,6 +24,17 @@ pub struct FieldAttributes {
pub documents: Vec<Attribute>,
pub setters: Setters,
pub vis: FieldVisibility,
pub late_bound_default: bool,
pub infer: Vec<Ident>,
}

pub fn ident_add_underscore(ident: &Ident) -> Ident {
let ident_ = ident.to_string() + "_";
Ident::new(&ident_, ident.span())
}

pub fn ident_add_underscore_tree(ident: &Ident) -> TokenTree {
TokenTree::Ident(ident_add_underscore(ident))
}

impl Default for FieldAttributes {
Expand All @@ -34,6 +46,8 @@ impl Default for FieldAttributes {
documents: vec![],
setters: Setters::VALUE,
vis: FieldVisibility::Default,
late_bound_default: false,
infer: vec![],
}
}
}
Expand Down Expand Up @@ -62,6 +76,7 @@ impl From<Vec<Attribute>> for FieldAttributes {
unimplemented!("Duplicated `hidden` attributes.")
}
attributes.vis = FieldVisibility::Hidden;
attributes.late_bound_default = true;
} else if attr.path.is_ident("public") {
if attributes.vis != FieldVisibility::Default {
unimplemented!("Duplicated `public` attributes.")
Expand All @@ -75,6 +90,10 @@ impl From<Vec<Attribute>> for FieldAttributes {
attributes.documents = get_documents(&attrs);
} else if attr.path.is_ident("setter") {
parse_setters(attr, &mut attributes)
} else if attr.path.is_ident("infer") {
parse_infer(attr, &mut attributes)
} else if attr.path.is_ident("late_bound_default") {
attributes.late_bound_default = true;
}
});
match attributes.validate() {
Expand Down Expand Up @@ -129,6 +148,28 @@ fn parse_setters(attr: &Attribute, attributes: &mut FieldAttributes) {
attributes.setters = setters;
}

fn parse_infer(attr: &Attribute, attributes: &mut FieldAttributes) {
let meta = attr.parse_meta().unwrap();
let mut params = vec![];
if let Meta::List(l) = meta {
let it = l.nested.iter();
it.for_each(|m| {
if let NestedMeta::Meta(Meta::Path(p)) = m {
if let Some(ident) = p.get_ident() {
params.push(ident.clone());
} else {
unimplemented!("Invalid infer, write a type parameter.")
}
} else {
unimplemented!("Invalid setter.")
}
});
} else {
unimplemented!("Invalid setter.")
}
attributes.infer = params;
}

pub fn get_documents(attrs: &[Attribute]) -> Vec<Attribute> {
let mut documents: Vec<Attribute> = vec![];

Expand Down
7 changes: 4 additions & 3 deletions builder-pattern-macro/src/builder/builder_decl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ impl<'a> ToTokens for BuilderDecl<'a> {
let builder_name = self.input.builder_name();
let where_clause = &self.input.generics.where_clause;

let impl_tokens = self.input.tokenize_impl();
let impl_tokens = self.input.tokenize_impl(&[]);
let all_generics = self.input.all_generics().collect::<Vec<_>>();
let ty_tokens = self.input.tokenize_types();
let ty_tokens = self.input.tokenize_types(&[], false);

let fn_lifetime = self.input.fn_lifetime();
let builder_fields = self.input.builder_fields(&fn_lifetime);
Expand All @@ -41,7 +41,8 @@ impl<'a> ToTokens for BuilderDecl<'a> {
AsyncFieldMarker,
ValidatorOption
> #where_clause {
_phantom: ::core::marker::PhantomData<(
__builder_phantom: ::core::marker::PhantomData<(
&#fn_lifetime (),
#ty_tokens
#(#all_generics,)*
AsyncFieldMarker,
Expand Down
100 changes: 76 additions & 24 deletions builder-pattern-macro/src/builder/builder_functions.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
attributes::{FieldVisibility, Setters},
attributes::{ident_add_underscore_tree, FieldVisibility, Setters},
field::Field,
struct_input::StructInput,
};

use core::str::FromStr;
use proc_macro2::{Ident, Span, TokenStream};
use proc_macro2::{Group, Ident, Span, TokenStream, TokenTree};
use quote::ToTokens;
use syn::{parse_quote, spanned::Spanned, Attribute};

Expand All @@ -22,7 +22,21 @@ impl<'a> ToTokens for BuilderFunctions<'a> {
.chain(self.input.optional_fields.iter())
.map(|f| {
let ident = &f.ident;
quote! { #ident: self.#ident }
if f.attrs.late_bound_default {
quote! {
#ident: match self.#ident {
Some(::builder_pattern::setter::Setter::LateBoundDefault(d)) => {
Some(::builder_pattern::setter::Setter::LateBoundDefault(d))
}
Some(::builder_pattern::setter::Setter::Value(val)) => {
Some(::builder_pattern::setter::Setter::Value(val))
}
_ => unreachable!(),
}
}
} else {
quote! { #ident: self.#ident }
}
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -52,6 +66,25 @@ impl<'a> ToTokens for BuilderFunctions<'a> {
}
}

pub fn replace_type_params_in(
stream: TokenStream,
replacements: &[Ident],
with: &impl Fn(&Ident) -> TokenTree,
) -> TokenStream {
stream
.into_iter()
.map(|tt| match tt {
TokenTree::Group(g) => {
let delim = g.delimiter();
let stream = replace_type_params_in(g.stream(), replacements, with);
TokenTree::Group(Group::new(delim, stream))
}
TokenTree::Ident(ident) if replacements.contains(&ident) => with(&ident),
x => x,
})
.collect()
}

impl<'a> BuilderFunctions<'a> {
pub fn new(input: &'a StructInput) -> Self {
Self { input }
Expand Down Expand Up @@ -101,21 +134,38 @@ impl<'a> BuilderFunctions<'a> {
index: usize,
builder_fields: &mut Vec<TokenStream>,
) {
let (ident, ty, vis) = (&f.ident, &f.ty, &f.vis);
let (ident, orig_ty, vis) = (&f.ident, &f.ty, &f.vis);
let builder_name = self.input.builder_name();
let where_clause = &self.input.generics.where_clause;
let lifetimes = self.input.lifetimes();
let fn_lifetime = self.input.fn_lifetime();
let impl_tokens = self.input.tokenize_impl();
let ty_tokens = self.input.tokenize_types();
let (other_generics, before_generics, after_generics) = self.get_generics(f, index);
let (arg_type_gen, arg_type) = if f.attrs.use_into {
(
Some(quote! {<IntoType: Into<#ty>>}),
TokenStream::from_str("IntoType").unwrap(),
)
let impl_tokens = self.input.tokenize_impl(&[]);
let ty_tokens = self.input.tokenize_types(&[], false);
let ty_tokens_ = self.input.tokenize_types(&f.attrs.infer, false);
let fn_where_clause = self.input.setter_where_clause(&f.attrs.infer);
let (other_generics, before_generics, mut after_generics) = self.get_generics(f, index);
let replaced_ty = replace_type_params_in(
quote! { #orig_ty },
&f.attrs.infer,
&ident_add_underscore_tree,
);
after_generics
.iter_mut()
.for_each(|ty_tokens: &mut TokenStream| {
let tokens = std::mem::take(ty_tokens);
*ty_tokens =
replace_type_params_in(tokens, &f.attrs.infer, &ident_add_underscore_tree);
});
let into_generics = if f.attrs.use_into {
vec![quote! {IntoType: Into<#replaced_ty>}]
} else {
vec![]
};
let fn_generics = f.tokenize_replacement_params(&into_generics);
let arg_type = if f.attrs.use_into {
quote! { IntoType }
} else {
(None, quote! {#ty})
quote! { #replaced_ty }
};
let documents = Self::documents(f, Setters::VALUE);

Expand All @@ -131,7 +181,7 @@ impl<'a> BuilderFunctions<'a> {
Result<#builder_name <
#fn_lifetime,
#(#lifetimes,)*
#ty_tokens
#ty_tokens_
#(#after_generics,)*
AsyncFieldMarker,
ValidatorOption
Expand All @@ -142,7 +192,7 @@ impl<'a> BuilderFunctions<'a> {
match #v (value.into()) {
Ok(value) => Ok(
#builder_name {
_phantom: ::core::marker::PhantomData,
__builder_phantom: ::core::marker::PhantomData,
#(#builder_fields),*
}),
Err(e) => Err(format!("Validation failed: {:?}", e))
Expand All @@ -161,15 +211,15 @@ impl<'a> BuilderFunctions<'a> {
#builder_name <
#fn_lifetime,
#(#lifetimes,)*
#ty_tokens
#ty_tokens_
#(#after_generics,)*
AsyncFieldMarker,
ValidatorOption
>
},
quote! {
#builder_name {
_phantom: ::core::marker::PhantomData,
__builder_phantom: ::core::marker::PhantomData,
#(#builder_fields),*
}
},
Expand All @@ -195,7 +245,9 @@ impl<'a> BuilderFunctions<'a> {
#where_clause
{
#(#documents)*
#vis fn #ident #arg_type_gen(self, value: #arg_type) -> #ret_type {
#vis fn #ident #fn_generics(self, value: #arg_type) -> #ret_type
#fn_where_clause
{
#ret_expr
}
}
Expand All @@ -215,8 +267,8 @@ impl<'a> BuilderFunctions<'a> {
let where_clause = &self.input.generics.where_clause;
let lifetimes = self.input.lifetimes();
let fn_lifetime = self.input.fn_lifetime();
let impl_tokens = self.input.tokenize_impl();
let ty_tokens = self.input.tokenize_types();
let impl_tokens = self.input.tokenize_impl(&[]);
let ty_tokens = self.input.tokenize_types(&[], false);
let (other_generics, before_generics, after_generics) = self.get_generics(f, index);
let arg_type_gen = if f.attrs.use_into {
quote! {<IntoType: Into<#ty>, ValType: #fn_lifetime + ::core::ops::Fn() -> IntoType>}
Expand Down Expand Up @@ -244,7 +296,7 @@ impl<'a> BuilderFunctions<'a> {
};
let ret_expr_val = quote! {
#builder_name {
_phantom: ::core::marker::PhantomData,
__builder_phantom: ::core::marker::PhantomData,
#(#builder_fields),*
}
};
Expand Down Expand Up @@ -305,8 +357,8 @@ impl<'a> BuilderFunctions<'a> {
let where_clause = &self.input.generics.where_clause;
let lifetimes = self.input.lifetimes();
let fn_lifetime = self.input.fn_lifetime();
let impl_tokens = self.input.tokenize_impl();
let ty_tokens = self.input.tokenize_types();
let impl_tokens = self.input.tokenize_impl(&[]);
let ty_tokens = self.input.tokenize_types(&[], false);
let (other_generics, before_generics, after_generics) = self.get_generics(f, index);
let arg_type_gen = if f.attrs.use_into {
quote! {<
Expand Down Expand Up @@ -343,7 +395,7 @@ impl<'a> BuilderFunctions<'a> {
};
let ret_expr_val = quote! {
#builder_name {
_phantom: ::core::marker::PhantomData,
__builder_phantom: ::core::marker::PhantomData,
#(#builder_fields),*
}
};
Expand Down
Loading