diff --git a/gcmodule_derive/Cargo.toml b/gcmodule_derive/Cargo.toml index 7a93b03..2061481 100644 --- a/gcmodule_derive/Cargo.toml +++ b/gcmodule_derive/Cargo.toml @@ -13,6 +13,7 @@ proc-macro = true [dependencies] quote = "1" syn = { version = "1", features = ["derive"] } +proc-macro2 = "1.0.32" [dev-dependencies] -gcmodule = { path = ".." } \ No newline at end of file +gcmodule = { path = ".." } diff --git a/gcmodule_derive/src/lib.rs b/gcmodule_derive/src/lib.rs index 1b14083..48ff3cb 100644 --- a/gcmodule_derive/src/lib.rs +++ b/gcmodule_derive/src/lib.rs @@ -20,82 +20,279 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::quote; -use quote::ToTokens; -use syn::Data; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{ + parenthesized, + parse::{Parse, ParseStream}, + parse_macro_input, + spanned::Spanned, + Attribute, Data, DeriveInput, Error, Field, Fields, Ident, Path, Result, +}; -#[proc_macro_derive(Trace, attributes(trace))] -pub fn gcmodule_trace_derive(input: TokenStream) -> TokenStream { - let input = syn::parse_macro_input!(input as syn::DeriveInput); - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let ident = input.ident; - let mut trace_fn_body = Vec::new(); - let mut is_type_tracked_fn_body = Vec::new(); - if !input.attrs.into_iter().any(is_skipped) { - match input.data { - Data::Struct(data) => { - for (i, field) in data.fields.into_iter().enumerate() { - if field.attrs.into_iter().any(is_skipped) { - continue; - } - let trace_field = match field.ident { - Some(i) => quote! { - if gcmodule::DEBUG_ENABLED { - eprintln!("[gc] Trace({}): visit .{}", stringify!(#ident), stringify!(#i)); - } - self.#i.trace(tracer); - }, - None => { - let i = syn::Index::from(i); - quote! { - if gcmodule::DEBUG_ENABLED { - eprintln!("[gc] Trace({}): visit .{}", stringify!(#ident), stringify!(#i)); - } - self.#i.trace(tracer); - } - } - }; - trace_fn_body.push(trace_field); - let ty = field.ty; - is_type_tracked_fn_body.push(quote! { - if <#ty as _gcmodule::Trace>::is_type_tracked() { - return true; - } - }); - } +mod kw { + syn::custom_keyword!(trace); + syn::custom_keyword!(skip); + syn::custom_keyword!(with); + syn::custom_keyword!(tracking); + syn::custom_keyword!(ignore); + syn::custom_keyword!(force); +} + +enum TraceAttr { + Skip, + With(Path), + TrackingForce(bool), +} +impl TraceAttr { + fn force_is_type_tracked(&self) -> Option { + match self { + Self::TrackingForce(v) => Some(quote! {#v}), + Self::Skip => Some(quote! {false}), + Self::With(_) => Some(quote! {true}), + } + } +} +impl Parse for TraceAttr { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::skip) { + input.parse::()?; + Ok(Self::Skip) + } else if lookahead.peek(kw::tracking) { + input.parse::()?; + let content; + parenthesized!(content in input); + let lookahead = content.lookahead1(); + if lookahead.peek(kw::ignore) { + content.parse::()?; + Ok(Self::TrackingForce(false)) + } else if lookahead.peek(kw::force) { + content.parse::()?; + Ok(Self::TrackingForce(true)) + } else { + Err(lookahead.error()) + } + } else if lookahead.peek(kw::with) { + input.parse::()?; + let content; + parenthesized!(content in input); + Ok(Self::With(content.parse()?)) + } else { + Err(lookahead.error()) + } + } +} + +fn parse_attr(attrs: &[Attribute], ident: I) -> Result> +where + Ident: PartialEq, +{ + let attrs = attrs + .iter() + .filter(|a| a.path.is_ident(&ident)) + .collect::>(); + if attrs.len() > 1 { + return Err(Error::new( + attrs[1].span(), + "this attribute may be specified only once", + )); + } else if attrs.is_empty() { + return Ok(None); + } + let attr = attrs[0]; + let attr = attr.parse_args::()?; + + Ok(Some(attr)) +} + +/// Returns impl for (trace, is_type_tracked) +fn derive_fields( + trace_attr: &Option, + fields: &Fields, +) -> Result<(TokenStream2, TokenStream2)> { + fn inner<'a>(names: &[Ident], fields: Vec<&Field>) -> Result<(TokenStream2, TokenStream2)> { + let attrs = fields + .iter() + .map(|f| parse_attr::(&f.attrs, "trace")) + .collect::>>()?; + + let trace = names.iter().zip(attrs.iter()).filter_map(|(name, attr)| { + match attr { + Some(TraceAttr::Skip) => return None, + Some(TraceAttr::With(w)) => return Some(quote! {#w(#name, tracer)}), + _ => {} } - Data::Enum(_) | Data::Union(_) => { - trace_fn_body.push(quote! { - compile_error!("enum or union are not supported"); - }); + Some(quote! { + ::gcmodule::Trace::trace(#name, tracer) + }) + }); + let is_type_tracked = fields.iter().zip(attrs.iter()).filter_map(|(field, attr)| { + match attr { + Some(TraceAttr::Skip | TraceAttr::TrackingForce(false)) => return None, + Some(TraceAttr::With(_) | TraceAttr::TrackingForce(true)) => { + return Some(quote! {true}) + } + _ => {} } + let ty = &field.ty; + Some(quote! { + <#ty as ::gcmodule::Trace>::is_type_tracked() + }) + }); + + let trace = quote! { + #(#trace;)* }; + + Ok(( + trace, + quote! { + #(if #is_type_tracked {return true;})* + }, + )) } - let generated = quote! { - const _: () = { - extern crate gcmodule as _gcmodule; - impl #impl_generics _gcmodule::Trace for #ident #ty_generics #where_clause { - fn trace(&self, tracer: &mut _gcmodule::Tracer) { - #( #trace_fn_body )* + match fields { + Fields::Named(named) => { + if matches!(trace_attr, Some(TraceAttr::Skip)) { + return Ok(( + quote! { + {...} => {} + }, + quote! {}, + )); + } + let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked()); + + let names = named + .named + .iter() + .map(|i| i.ident.clone().unwrap()) + .collect::>(); + let (trace, is_type_tracked) = inner(&names, named.named.iter().collect())?; + let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked); + + Ok(( + quote! { + {#(#names),*} => {#trace} + }, + is_type_tracked, + )) + } + Fields::Unnamed(unnamed) => { + if matches!(trace_attr, Some(TraceAttr::Skip)) { + return Ok((quote! {(...) => {}}, quote! {})); + } + let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked()); + + let names = (0..unnamed.unnamed.len()) + .map(|i| format_ident!("field_{}", i)) + .collect::>(); + let (trace, is_type_tracked) = inner(&names, unnamed.unnamed.iter().collect())?; + let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked); + + Ok(( + quote! { + (#(#names,)*) => {#trace} + }, + is_type_tracked, + )) + } + Fields::Unit => Ok(( + quote! { + => {} + }, + quote! {}, + )), + } +} + +fn derive_trace(input: DeriveInput) -> Result { + let trace_attr = parse_attr::(&input.attrs, "trace")?; + if matches!(trace_attr, Some(TraceAttr::With(_))) { + return Err(Error::new(input.span(), "implement Trace instead")); + } + let ident = &input.ident; + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); + if matches!(trace_attr, Some(TraceAttr::Skip)) { + return Ok(quote! { + impl #impl_generics ::gcmodule::Trace for #ident #type_generics #where_clause { + fn trace(&self, _tracer: &mut ::gcmodule::Tracer) { } fn is_type_tracked() -> bool { - #( #is_type_tracked_fn_body )* false } } - }; - }; - generated.into() -} + }); + } + let force_is_type_tracked = trace_attr.and_then(|a| a.force_is_type_tracked()); + let (trace, is_type_tracked) = match &input.data { + Data::Struct(s) => { + let (trace, is_type_tracked) = derive_fields(&None, &s.fields)?; -fn is_skipped(attr: syn::Attribute) -> bool { - // check if `#[trace(skip)]` exists. - if attr.path.to_token_stream().to_string() == "trace" { - for token in attr.tokens { - if token.to_string() == "(skip)" { - return true; + ( + quote! { + Self#trace + }, + quote! { + #is_type_tracked + false + }, + ) + } + Data::Enum(e) if e.variants.is_empty() => (quote! {_=>unreachable!()}, quote! {false}), + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|v| { + let name = &v.ident; + let attr = parse_attr::(&v.attrs, "trace")?; + let impls = derive_fields(&attr, &v.fields)?; + Ok((name, impls)) as Result<_> + }) + .collect::>>()?; + + let trace = variants.iter().map(|(name, (trace, _))| { + quote! { + Self::#name #trace + } + }); + let is_type_tracked = variants.iter().map(|(_, (_, v))| v); + + ( + quote! { + #(#trace),* + }, + quote! { + #(#is_type_tracked)* + false + }, + ) + } + + Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")), + }; + let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked); + Ok(quote! { + impl #impl_generics ::gcmodule::Trace for #ident #type_generics #where_clause { + fn trace(&self, tracer: &mut ::gcmodule::Tracer) { + match self { + #trace + } + } + fn is_type_tracked() -> bool { + #is_type_tracked } } + }) +} + +#[proc_macro_derive(Trace, attributes(trace))] +pub fn derive_trace_real(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match derive_trace(input) { + Ok(v) => v.into(), + Err(e) => e.to_compile_error().into(), } - false } diff --git a/gcmodule_derive/tests/trace.rs b/gcmodule_derive/tests/trace.rs index cf87809..882f806 100644 --- a/gcmodule_derive/tests/trace.rs +++ b/gcmodule_derive/tests/trace.rs @@ -1,8 +1,22 @@ -use gcmodule::{Cc, Trace}; +use gcmodule::{Cc, Trace, Tracer}; use gcmodule_derive::Trace as DeriveTrace; use std::cell::RefCell; use std::rc::Rc; +#[test] +fn test_empty() { + #[derive(DeriveTrace)] + struct S0; + + #[derive(DeriveTrace)] + enum E0 {} + + #[derive(DeriveTrace)] + enum E1 { + _A, + } +} + #[test] fn test_named_struct() { #[derive(DeriveTrace)] @@ -75,6 +89,31 @@ fn test_container_skip() { assert!(!E0::is_type_tracked()); } +#[test] +fn test_recursive_struct() { + #[derive(DeriveTrace)] + struct A { + b: Box, + #[trace(tracking(ignore))] + a: Box, + } + assert!(A::is_type_tracked()); + + #[derive(DeriveTrace)] + struct B { + #[trace(tracking(ignore))] + b: Box, + } + assert!(!B::is_type_tracked()); + + #[derive(DeriveTrace)] + #[trace(tracking(force))] + struct C { + c: (Box, Box), + } + assert!(C::is_type_tracked()); +} + #[test] fn test_unnamed_struct() { #[derive(DeriveTrace)] @@ -100,3 +139,13 @@ fn test_real_cycles() { } assert_eq!(gcmodule::collect_thread_cycles(), 3); } + +#[test] +fn test_with() { + struct Child; + + fn trace_child(_child: &Child, _tracer: &mut Tracer) {} + + #[derive(DeriveTrace)] + struct Parent(#[trace(with(trace_child))] Child); +} diff --git a/src/lib.rs b/src/lib.rs index 2a0a5c7..b5380c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,7 +131,7 @@ //! assert_eq!(gcmodule::count_thread_tracked(), 1); //! ``` //! -//! The `#[trace(skip)]` attribute can be used to skip tracking specified fields +//! The `#[skip_trace]` attribute can be used to skip tracking specified fields //! in a structure. //! //! ```