|
1 | 1 | use proc_macro::TokenStream; |
2 | | -use proc_macro2::Span; |
3 | 2 | use quote::{ToTokens, quote_spanned}; |
4 | 3 | use syn::{ |
5 | | - Error, FnArg, Ident, ItemFn, ReturnType, Stmt, Token, parse::Parse, parse_macro_input, |
6 | | - parse_quote, punctuated::Punctuated, spanned::Spanned, |
| 4 | + FnArg, Ident, ItemFn, ReturnType, Stmt, parse_macro_input, parse_quote, spanned::Spanned, |
7 | 5 | }; |
8 | 6 |
|
9 | 7 | /// Registers a function as a gpu kernel. |
10 | 8 | /// |
11 | 9 | /// This attribute must always be placed on gpu kernel functions. |
12 | 10 | /// |
13 | 11 | /// This attribute does a couple of things: |
14 | | -/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx file. |
| 12 | +/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx |
| 13 | +/// file. |
15 | 14 | /// - Marks the function as `no_mangle`. |
16 | 15 | /// - Errors if the function is not unsafe. |
17 | 16 | /// - Makes sure function parameters are all [`Copy`]. |
18 | 17 | /// - Makes sure the function doesn't return anything. |
19 | 18 | /// |
20 | | -/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer is able to |
21 | | -/// offer intellisense by default. |
| 19 | +/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer |
| 20 | +/// is able to offer intellisense by default. |
22 | 21 | #[proc_macro_attribute] |
23 | 22 | pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream { |
24 | 23 | let cloned = input.clone(); |
25 | | - let _ = parse_macro_input!(input as KernelHints); |
26 | 24 | let input = parse_macro_input!(cloned as proc_macro2::TokenStream); |
27 | 25 | let mut item = parse_macro_input!(item as ItemFn); |
28 | 26 | let no_mangle = parse_quote!(#[unsafe(no_mangle)]); |
29 | 27 | item.attrs.push(no_mangle); |
30 | 28 | let internal = parse_quote!(#[cfg_attr(target_arch="nvptx64", nvvm_internal::kernel(#input))]); |
31 | 29 | item.attrs.push(internal); |
32 | 30 |
|
33 | | - // used to guarantee some things about how params are passed in the codegen. |
| 31 | + // Used to guarantee some things about how params are passed in the codegen. |
34 | 32 | item.sig.abi = Some(parse_quote!(extern "C")); |
35 | 33 |
|
36 | 34 | let check_fn = parse_quote! { |
@@ -71,80 +69,10 @@ pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> |
71 | 69 | item.to_token_stream().into() |
72 | 70 | } |
73 | 71 |
|
74 | | -#[derive(Debug, Clone, Copy, PartialEq)] |
75 | | -enum Dimension { |
76 | | - Dim1, |
77 | | - Dim2, |
78 | | - Dim3, |
79 | | -} |
80 | | - |
81 | | -impl Parse for Dimension { |
82 | | - fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { |
83 | | - let val = Ident::parse(input)?; |
84 | | - let val = val.to_string(); |
85 | | - match val.as_str() { |
86 | | - "1d" | "1D" => Ok(Self::Dim1), |
87 | | - "2d" | "2D" => Ok(Self::Dim2), |
88 | | - "3d" | "3D" => Ok(Self::Dim3), |
89 | | - _ => Err(syn::Error::new(Span::call_site(), "Invalid dimension")), |
90 | | - } |
91 | | - } |
92 | | -} |
93 | | - |
94 | | -enum KernelHint { |
95 | | - GridDim(Dimension), |
96 | | - BlockDim(Dimension), |
97 | | -} |
98 | | - |
99 | | -impl Parse for KernelHint { |
100 | | - fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { |
101 | | - let name = Ident::parse(input)?; |
102 | | - let key = name.to_string(); |
103 | | - <Token![=]>::parse(input)?; |
104 | | - match key.as_str() { |
105 | | - "grid_dim" => { |
106 | | - let dim = Dimension::parse(input)?; |
107 | | - Ok(Self::GridDim(dim)) |
108 | | - } |
109 | | - "block_dim" => { |
110 | | - let dim = Dimension::parse(input)?; |
111 | | - Ok(Self::BlockDim(dim)) |
112 | | - } |
113 | | - _ => Err(Error::new(Span::call_site(), "Unrecognized option")), |
114 | | - } |
115 | | - } |
116 | | -} |
117 | | - |
118 | | -#[derive(Debug, Default, Clone, PartialEq)] |
119 | | -struct KernelHints { |
120 | | - grid_dim: Option<Dimension>, |
121 | | - block_dim: Option<Dimension>, |
122 | | -} |
123 | | - |
124 | | -impl Parse for KernelHints { |
125 | | - fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { |
126 | | - let iter = Punctuated::<KernelHint, Token![,]>::parse_terminated(input)?; |
127 | | - let hints = iter |
128 | | - .into_pairs() |
129 | | - .map(|x| x.into_value()) |
130 | | - .collect::<Vec<_>>(); |
131 | | - |
132 | | - let mut out = KernelHints::default(); |
133 | | - |
134 | | - for hint in hints { |
135 | | - match hint { |
136 | | - KernelHint::GridDim(dim) => out.grid_dim = Some(dim), |
137 | | - KernelHint::BlockDim(dim) => out.block_dim = Some(dim), |
138 | | - } |
139 | | - } |
140 | | - |
141 | | - Ok(out) |
142 | | - } |
143 | | -} |
144 | | - |
145 | 72 | // derived from rust-gpu's gpu_only |
146 | 73 |
|
147 | | -/// Creates a cpu version of the function which panics and cfg-gates the function for only nvptx/nvptx64. |
| 74 | +/// Creates a cpu version of the function which panics and cfg-gates the function for only |
| 75 | +/// nvptx/nvptx64. |
148 | 76 | #[proc_macro_attribute] |
149 | 77 | pub fn gpu_only(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream { |
150 | 78 | let syn::ItemFn { |
|
0 commit comments