Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,18 @@ where
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

#[trait_variant::make(Send + Sync)]
pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize>
where
Y: Sync,
{
const CONST: usize = 3;
type F;
type A<const ANOTHER_CONST: u8>;
type B<T: Display>: FromIterator<T>;

async fn take(&self, s: S);
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

fn main() {}
19 changes: 16 additions & 3 deletions trait-variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ mod variant;
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// #[trait_variant::make(Send)]
/// trait IntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The above example causes a second trait called `IntFactory` to be created:
/// The above example causes the trait to be rewritten as:
///
/// ```
/// # use core::future::Future;
Expand All @@ -35,6 +35,19 @@ mod variant;
///
/// Note that ordinary methods such as `call` are not affected.
///
/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The example causes a second trait called `IntFactory` to be created.
/// Implementers of the trait can choose to implement the variant instead of the
/// original trait. The macro creates a blanket impl which ensures that any type
/// which implements the variant also implements the original trait.
Expand Down
111 changes: 59 additions & 52 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,33 @@ impl Parse for Attrs {
}
}

struct MakeVariant {
name: Ident,
#[allow(unused)]
colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
enum MakeVariant {
// Creates a variant of a trait under a new name with additional bounds while preserving the original trait.
Create {
name: Ident,
_colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
},
// Rewrites the original trait into a new trait with additional bounds.
Rewrite {
bounds: Punctuated<TraitBound, Plus>,
},
}

impl Parse for MakeVariant {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
name: input.parse()?,
colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
})
let variant = if input.peek(Ident) && input.peek2(Token![:]) {
MakeVariant::Create {
name: input.parse()?,
_colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
} else {
MakeVariant::Rewrite {
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
};
Ok(variant)
}
}

Expand All @@ -56,43 +69,51 @@ pub fn make(
let attrs = parse_macro_input!(attr as Attrs);
let item = parse_macro_input!(item as ItemTrait);

let maybe_allow_async_lint = if attrs
.variant
.bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};
match attrs.variant {
MakeVariant::Create { name, bounds, .. } => {
let maybe_allow_async_lint = if bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};

let variant = mk_variant(&attrs, &item);
let blanket_impl = mk_blanket_impl(&attrs, &item);
let variant = mk_variant(&name, bounds, &item);
let blanket_impl = mk_blanket_impl(&name, &item);

quote! {
#maybe_allow_async_lint
#item
quote! {
#maybe_allow_async_lint
#item

#variant
#variant

#blanket_impl
#blanket_impl
}
.into()
}
MakeVariant::Rewrite { bounds, .. } => {
let variant = mk_variant(&item.ident, bounds, &item);
quote! {
#variant
}
.into()
}
}
.into()
}

fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let MakeVariant {
ref name,
colon: _,
ref bounds,
} = attrs.variant;
let bounds: Vec<_> = bounds
fn mk_variant(
variant: &Ident,
with_bounds: Punctuated<TraitBound, Plus>,
tr: &ItemTrait,
) -> TokenStream {
let bounds: Vec<_> = with_bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
.collect();
let variant = ItemTrait {
ident: name.clone(),
ident: variant.clone(),
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
items: tr
.items
Expand All @@ -104,21 +125,8 @@ fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
quote! { #variant }
}

// Transforms one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
// #[make_variant(SendIntFactory: Send)]
// trait IntFactory {
// async fn make(&self, x: u32, y: &str) -> i32;
// fn stream(&self) -> impl Iterator<Item = i32>;
// fn call(&self) -> u32;
// }
//
// becomes:
//
// trait SendIntFactory: Send {
// fn make(&self, x: u32, y: &str) -> impl ::core::future::Future<Output = i32> + Send;
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
// fn call(&self) -> u32;
// }
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
return item.clone();
};
Expand Down Expand Up @@ -160,9 +168,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
})
}

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let variant = &attrs.variant.name;
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
let items = tr
.items
Expand Down