diff --git a/clap_derive/src/derives/args.rs b/clap_derive/src/derives/args.rs index d6fecc01970..7cd5f50d56f 100644 --- a/clap_derive/src/derives/args.rs +++ b/clap_derive/src/derives/args.rs @@ -18,6 +18,7 @@ use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Generics, }; +use syn::{DataEnum, Variant}; use crate::item::{Item, Kind, Name}; use crate::utils::{inner_type, sub_type, Sp, Ty}; @@ -51,10 +52,188 @@ pub(crate) fn derive_args(input: &DeriveInput) -> Result, syn::Error>>()?; gen_for_struct(&item, ident, &input.generics, &fields) } + Data::Enum(DataEnum { ref variants, .. }) => { + let name = Name::Derived(ident.clone()); + let item = Item::from_args_struct(input, name)?; + + let variant_items = variants + .iter() + .map(|variant| { + let item = + Item::from_args_enum_variant(variant, item.casing(), item.env_casing())?; + Ok((item, variant)) + }) + .collect::, syn::Error>>()?; + + gen_for_enum(&item, ident, &input.generics, &variant_items) + } _ => abort_call_site!("`#[derive(Args)]` only supports non-tuple structs"), } } +pub(crate) fn gen_for_enum( + _item: &Item, + item_name: &Ident, + generics: &Generics, + variants: &[(Item, &Variant)], +) -> Result { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let app_var = Ident::new("__clap_app", Span::call_site()); + let mut augmentations = TokenStream::default(); + let mut augmentations_update = TokenStream::default(); + + let mut constructors = TokenStream::default(); + let mut updaters = TokenStream::default(); + + for (item, variant) in variants.iter() { + let Fields::Named(ref fields) = variant.fields else { + abort! { variant.span(), + "`#[derive(Args)]` only supports named enum variants if used on an enum", + } + }; + let group_id = item.group_id(); + + let conflicts = variants + .iter() + .filter_map(|(_, v)| { + if v.ident == variant.ident { + None + } else { + Some(Name::Derived(v.ident.clone())) + } + }) + .collect::>(); + + let fields = collect_args_fields(item, fields)?; + + let augmentation = gen_augment(&fields, &app_var, item, &conflicts, false)?; + let augmentation = quote! { + let #app_var = #augmentation; + }; + let augmentation_update = gen_augment(&fields, &app_var, item, &conflicts, true)?; + let augmentation_update = quote! { + let #app_var = #augmentation_update; + }; + + augmentations.extend(augmentation); + augmentations_update.extend(augmentation_update); + + let variant_name = &variant.ident; + let genned_constructor = gen_constructor(&fields)?; + let constructor = quote! { + if __clap_arg_matches.contains_id(#group_id) { + let v = #item_name::#variant_name #genned_constructor; + return ::std::result::Result::Ok(v) + } + }; + + constructors.extend(constructor); + + let genned_updater = gen_updater(&fields, false)?; + + let field_names = fields + .iter() + .map(|(field, _)| field.ident.as_ref().unwrap()); + let updater = quote! { + + if __clap_arg_matches.contains_id(#group_id) { + let #item_name::#variant_name { #( #field_names ),* } = self else { + unreachable!(); + }; + #genned_updater; + } + }; + + updaters.extend(updater); + } + + let raw_deprecated = raw_deprecated(); + + Ok(quote! { + #[allow( + dead_code, + unreachable_code, + unused_variables, + unused_braces, + unused_qualifications, + )] + #[allow( + clippy::style, + clippy::complexity, + clippy::pedantic, + clippy::restriction, + clippy::perf, + clippy::deprecated, + clippy::nursery, + clippy::cargo, + clippy::suspicious_else_formatting, + clippy::almost_swapped, + clippy::redundant_locals, + )] + #[automatically_derived] + impl #impl_generics clap::FromArgMatches for #item_name #ty_generics #where_clause { + fn from_arg_matches(__clap_arg_matches: &clap::ArgMatches) -> ::std::result::Result { + Self::from_arg_matches_mut(&mut __clap_arg_matches.clone()) + } + + fn from_arg_matches_mut(__clap_arg_matches: &mut clap::ArgMatches) -> ::std::result::Result { + #raw_deprecated + #constructors + unreachable!() + } + + fn update_from_arg_matches(&mut self, __clap_arg_matches: &clap::ArgMatches) -> ::std::result::Result<(), clap::Error> { + self.update_from_arg_matches_mut(&mut __clap_arg_matches.clone()) + } + + fn update_from_arg_matches_mut(&mut self, __clap_arg_matches: &mut clap::ArgMatches) -> ::std::result::Result<(), clap::Error> { + #raw_deprecated + #updaters + ::std::result::Result::Ok(()) + } + } + + + #[allow( + dead_code, + unreachable_code, + unused_variables, + unused_braces, + unused_qualifications, + )] + #[allow( + clippy::style, + clippy::complexity, + clippy::pedantic, + clippy::restriction, + clippy::perf, + clippy::deprecated, + clippy::nursery, + clippy::cargo, + clippy::suspicious_else_formatting, + clippy::almost_swapped, + clippy::redundant_locals, + )] + #[automatically_derived] + impl #impl_generics clap::Args for #item_name #ty_generics #where_clause { + fn group_id() -> Option { + // todo: how does this interact with nested groups here + None + } + fn augment_args<'b>(#app_var: clap::Command) -> clap::Command { + #augmentations + #app_var + } + fn augment_args_for_update<'b>(#app_var: clap::Command) -> clap::Command { + #augmentations_update + #app_var + } + } + + }) +} + pub(crate) fn gen_for_struct( item: &Item, item_name: &Ident, @@ -75,8 +254,8 @@ pub(crate) fn gen_for_struct( let raw_deprecated = raw_deprecated(); let app_var = Ident::new("__clap_app", Span::call_site()); - let augmentation = gen_augment(fields, &app_var, item, false)?; - let augmentation_update = gen_augment(fields, &app_var, item, true)?; + let augmentation = gen_augment(fields, &app_var, item, &[], false)?; + let augmentation_update = gen_augment(fields, &app_var, item, &[], true)?; let group_id = if item.skip_group() { quote!(None) @@ -170,6 +349,9 @@ pub(crate) fn gen_augment( fields: &[(&Field, Item)], app_var: &Ident, parent_item: &Item, + // when generating mutably exclusive arguments, + // ids of arguments that should conflict + conflicts: &[Name], override_required: bool, ) -> Result { let mut subcommand_specified = false; @@ -420,12 +602,25 @@ pub(crate) fn gen_augment( let group_methods = parent_item.group_methods(); + let conflicts_method = if conflicts.is_empty() { + quote!() + } else { + let conflicts_len = conflicts.len(); + quote! { + .conflicts_with_all({ + let conflicts: [clap::Id; #conflicts_len] = [#( clap::Id::from(#conflicts) ),* ]; + conflicts + }) + } + }; + quote!( .group( clap::ArgGroup::new(#group_id) .multiple(true) #group_methods .args(#literal_group_members) + #conflicts_method ) ) }; diff --git a/clap_derive/src/derives/subcommand.rs b/clap_derive/src/derives/subcommand.rs index 6853b65d39d..d162e5b5126 100644 --- a/clap_derive/src/derives/subcommand.rs +++ b/clap_derive/src/derives/subcommand.rs @@ -282,7 +282,7 @@ fn gen_augment( Named(ref fields) => { // Defer to `gen_augment` for adding cmd methods let fields = collect_args_fields(item, fields)?; - args::gen_augment(&fields, &subcommand_var, item, override_required)? + args::gen_augment(&fields, &subcommand_var, item, &[], override_required)? } Unit => { let arg_block = quote!( #subcommand_var ); diff --git a/clap_derive/src/item.rs b/clap_derive/src/item.rs index e48b200f704..6f5fc8dabde 100644 --- a/clap_derive/src/item.rs +++ b/clap_derive/src/item.rs @@ -70,6 +70,62 @@ impl Item { Ok(res) } + pub(crate) fn from_args_enum_variant( + variant: &Variant, + argument_casing: Sp, + env_casing: Sp, + ) -> Result { + let name = variant.ident.clone(); + let ident = variant.ident.clone(); + let span = variant.span(); + + let ty = match variant.fields { + syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => { + Ty::from_syn_ty(&unnamed[0].ty) + } + syn::Fields::Named(_) | syn::Fields::Unnamed(..) | syn::Fields::Unit => { + Sp::new(Ty::Other, span) + } + }; + let kind = Sp::new(Kind::Command(ty), span); + let mut res = Self::new( + Name::Derived(name), + ident, + None, + argument_casing, + env_casing, + kind, + ); + let parsed_attrs = ClapAttr::parse_all(&variant.attrs)?; + res.infer_kind(&parsed_attrs)?; + res.push_attrs(&parsed_attrs)?; + if matches!(&*res.kind, Kind::Command(_) | Kind::Subcommand(_)) { + res.push_doc_comment(&variant.attrs, "about", Some("long_about")); + } + + // TODO: ??? + match &*res.kind { + Kind::Flatten(_) => { + if res.has_explicit_methods() { + abort!( + res.kind.span(), + "methods are not allowed for flattened entry" + ); + } + } + + Kind::Subcommand(_) + | Kind::ExternalSubcommand + | Kind::FromGlobal(_) + | Kind::Skip(_, _) + | Kind::Command(_) + | Kind::Value + | Kind::Arg(_) => (), + } + + Ok(res) + } + pub(crate) fn from_subcommand_enum( input: &DeriveInput, name: Name, diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000000..64b32ad09d0 --- /dev/null +++ b/flake.lock @@ -0,0 +1,63 @@ +{ + "nodes": { + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1724567349, + "narHash": "sha256-w2G1EJlGvgRSC1OAm2147mCzlt6ZOWIiqX/TSJUgrGE=", + "owner": "nix-community", + "repo": "fenix", + "rev": "71fe264f6e208831aa0e7e54ad557a283c375014", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1724015816, + "narHash": "sha256-hVESnM7Eiz93+4DeiE0a1TwMeaeph1ytRJ5QtqxYRWg=", + "path": "/nix/store/d87mpjlsl8flcgazpjnhimq21y46a3g4-source", + "rev": "9aa35efbea27d320d0cdc5f922f0890812affb60", + "type": "path" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "root": { + "inputs": { + "fenix": "fenix", + "nixpkgs": "nixpkgs" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1724480527, + "narHash": "sha256-C+roFDGk6Bn/C58NGpyt7cneLCetdRMUfFTkm3O4zWM=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "74a6427861eb8d1e3b7c6090b2c2890ff4c53e0e", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000000..569f4b93d13 --- /dev/null +++ b/flake.nix @@ -0,0 +1,36 @@ +{ + inputs.fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + + outputs = + { nixpkgs, fenix, ... }: + let + pkgs = nixpkgs.legacyPackages.aarch64-darwin; + in + { + devShells.aarch64-darwin.default = pkgs.mkShell { + shellHook = '' + export RUST_SRC_PATH="${pkgs.rustPlatform.rustLibSrc}"; + ''; + packages = [ + pkgs.lldb + pkgs.libiconv + pkgs.darwin.apple_sdk.frameworks.Security + pkgs.darwin.apple_sdk.frameworks.SystemConfiguration + (fenix.packages.aarch64-darwin.combine [ + fenix.packages.aarch64-darwin.stable.cargo + fenix.packages.aarch64-darwin.stable.rust + fenix.packages.aarch64-darwin.targets.wasm32-wasi.stable.rust-std + fenix.packages.aarch64-darwin.targets.wasm32-wasip1.stable.rust-std + fenix.packages.aarch64-darwin.targets.aarch64-apple-darwin.stable.rust-std + ]) + ]; + }; + + formatter.aarch64-darwin = pkgs.nixfmt-rfc-style; + + }; + +} diff --git a/tests/derive/groups.rs b/tests/derive/groups.rs index 37141968e29..64fe92eb69d 100644 --- a/tests/derive/groups.rs +++ b/tests/derive/groups.rs @@ -239,3 +239,57 @@ For more information, try '--help'. "; assert_output::("test", OUTPUT, true); } + +#[test] +fn enum_groups_1() { + #[derive(Parser, Debug, PartialEq, Eq)] + struct Opt { + #[command(flatten)] + source: Source, + } + + #[derive(clap::Args, Clone, Debug, PartialEq, Eq)] + enum Source { + A { + #[arg(short)] + a: bool, + #[arg(long)] + aaa: bool, + }, + B { + #[arg(short)] + b: bool, + }, + } + + assert_eq!( + Opt { + source: Source::A { + a: true, + aaa: false, + } + }, + Opt::try_parse_from(["test", "-a"]).unwrap() + ); + assert_eq!( + Opt { + source: Source::A { a: true, aaa: true } + }, + Opt::try_parse_from(["test", "-a", "--aaa"]).unwrap() + ); + assert_eq!( + Opt { + source: Source::B { b: true } + }, + Opt::try_parse_from(["test", "-b"]).unwrap() + ); + + assert_eq!( + clap::error::ErrorKind::ArgumentConflict, + Opt::try_parse_from(["test", "-b", "-a"]) + .unwrap_err() + .kind(), + ); + + // assert_eq!( ) +}