Skip to content

Commit c05edf4

Browse files
authored
fix some cli parsing bugs and add unit test coverage (#2)
1 parent 96b4a84 commit c05edf4

File tree

8 files changed

+573
-35
lines changed

8 files changed

+573
-35
lines changed

pkg/cli/cli.go

+58-31
Original file line numberDiff line numberDiff line change
@@ -44,50 +44,91 @@ func New(binaryName string, shortUsage string, longUsage, examples string) Comma
4444
}
4545
}
4646

47-
// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args
48-
// and then perform validation
49-
func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) {
47+
// ParseFlags will parse flags registered in this instance of CLI from os.Args
48+
func (cl *CommandLineInterface) ParseFlags() (map[string]interface{}, error) {
5049
cl.setUsageTemplate()
5150
// Remove Suite Flags so that args only include Config and Filter Flags
52-
cl.rootCmd.SetArgs(cl.removeIntersectingArgs(cl.suiteFlags))
51+
cl.rootCmd.SetArgs(removeIntersectingArgs(cl.suiteFlags))
5352
// This parses Config and Filter flags only
5453
err := cl.rootCmd.Execute()
5554
if err != nil {
5655
return nil, err
5756
}
5857
// Remove Config and Filter flags so that only suite flags are parsed
59-
err = cl.suiteFlags.Parse(cl.removeIntersectingArgs(cl.rootCmd.Flags()))
58+
err = cl.suiteFlags.Parse(removeIntersectingArgs(cl.rootCmd.Flags()))
6059
if err != nil {
6160
return nil, err
6261
}
6362
// Add suite flags to rootCmd flagset so that other processing can occur
6463
// This has to be done after usage is printed so that the flagsets can be grouped properly when printed
6564
cl.rootCmd.Flags().AddFlagSet(cl.suiteFlags)
66-
err = cl.SetUntouchedFlagValuesToNil()
65+
err = cl.setUntouchedFlagValuesToNil()
6766
if err != nil {
6867
return nil, err
6968
}
70-
err = cl.ProcessRangeFilterFlags()
69+
err = cl.processRangeFilterFlags()
70+
if err != nil {
71+
return nil, err
72+
}
73+
return cl.Flags, nil
74+
}
75+
76+
// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args
77+
// and then perform validation
78+
func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) {
79+
flags, err := cl.ParseFlags()
7180
if err != nil {
7281
return nil, err
7382
}
7483
err = cl.ValidateFlags()
7584
if err != nil {
7685
return nil, err
7786
}
78-
return cl.Flags, nil
87+
return flags, nil
7988
}
8089

81-
func (cl *CommandLineInterface) removeIntersectingArgs(flagSet *pflag.FlagSet) []string {
82-
newArgs := os.Args[1:]
83-
for i, arg := range newArgs {
84-
if flagSet.Lookup(strings.Replace(arg, "--", "", 1)) != nil || (len(arg) == 2 && flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1)) != nil) {
85-
newArgs = append(newArgs[:i], newArgs[i+1:]...)
90+
// ValidateFlags iterates through any registered validators and executes them
91+
func (cl *CommandLineInterface) ValidateFlags() error {
92+
for flagName, validationFn := range cl.validators {
93+
if validationFn == nil {
94+
continue
95+
}
96+
err := validationFn(cl.Flags[flagName])
97+
if err != nil {
98+
return err
8699
}
87100
}
101+
return nil
102+
}
103+
104+
func removeIntersectingArgs(flagSet *pflag.FlagSet) []string {
105+
newArgs := []string{}
106+
skipNext := false
107+
for i, arg := range os.Args {
108+
if skipNext {
109+
skipNext = false
110+
continue
111+
}
112+
arg = strings.Split(arg, "=")[0]
113+
longFlag := strings.Replace(arg, "--", "", 1)
114+
if flagSet.Lookup(longFlag) != nil || shorthandLookup(flagSet, arg) != nil {
115+
if len(os.Args) > i+1 && os.Args[i+1][0] != '-' {
116+
skipNext = true
117+
}
118+
continue
119+
}
120+
newArgs = append(newArgs, os.Args[i])
121+
}
88122
return newArgs
89123
}
90124

125+
func shorthandLookup(flagSet *pflag.FlagSet, arg string) *pflag.Flag {
126+
if len(arg) == 2 && arg[0] == '-' && arg[1] != '-' {
127+
return flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1))
128+
}
129+
return nil
130+
}
131+
91132
func (cl *CommandLineInterface) setUsageTemplate() {
92133
transformedUsage := usageTemplate
93134
suiteFlagCount := 0
@@ -104,9 +145,9 @@ func (cl *CommandLineInterface) setUsageTemplate() {
104145
cl.rootCmd.Flags().Usage = func() {}
105146
}
106147

107-
// SetUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user
148+
// setUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user
108149
// This allows for a specified value, a negative value (like false or empty string), or an unspecified (nil) entry.
109-
func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error {
150+
func (cl *CommandLineInterface) setUntouchedFlagValuesToNil() error {
110151
defaultHandlerErrMsg := "Unable to find a default value handler for %v, marking as no default value. This could be an error"
111152
defaultHandlerFlags := []string{}
112153

@@ -141,8 +182,8 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error {
141182
return nil
142183
}
143184

144-
// ProcessRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max
145-
func (cl *CommandLineInterface) ProcessRangeFilterFlags() error {
185+
// processRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max
186+
func (cl *CommandLineInterface) processRangeFilterFlags() error {
146187
for flagName := range cl.intRangeFlags {
147188
rangeHelperMin := fmt.Sprintf("%s-%s", flagName, "min")
148189
rangeHelperMax := fmt.Sprintf("%s-%s", flagName, "max")
@@ -167,17 +208,3 @@ func (cl *CommandLineInterface) ProcessRangeFilterFlags() error {
167208
}
168209
return nil
169210
}
170-
171-
// ValidateFlags iterates through any registered validators and executes them
172-
func (cl *CommandLineInterface) ValidateFlags() error {
173-
for flagName, validationFn := range cl.validators {
174-
if validationFn == nil {
175-
continue
176-
}
177-
err := validationFn(cl.Flags[flagName])
178-
if err != nil {
179-
return err
180-
}
181-
}
182-
return nil
183-
}

pkg/cli/cli_internal_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
4+
// not use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
// express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
package cli
15+
16+
import (
17+
"os"
18+
"testing"
19+
20+
h "github.com/aws/amazon-ec2-instance-selector/pkg/test"
21+
"github.com/spf13/pflag"
22+
)
23+
24+
// Tests
25+
26+
func TestRemoveIntersectingArgs(t *testing.T) {
27+
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
28+
flagSet.Bool("test-bool", false, "test usage")
29+
os.Args = []string{"ec2-instance-selector", "--test-bool", "--this-should-stay"}
30+
newArgs := removeIntersectingArgs(flagSet)
31+
h.Assert(t, len(newArgs) == 2, "NewArgs should only include the bin name and one argument after removing intersections")
32+
}
33+
34+
func TestRemoveIntersectingArgs_NextArg(t *testing.T) {
35+
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
36+
flagSet.String("test-str", "", "test usage")
37+
os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay"}
38+
newArgs := removeIntersectingArgs(flagSet)
39+
h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections")
40+
}
41+
42+
func TestRemoveIntersectingArgs_ShorthandArg(t *testing.T) {
43+
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
44+
flagSet.StringP("test-str", "t", "", "test usage")
45+
os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay", "-t", "test"}
46+
newArgs := removeIntersectingArgs(flagSet)
47+
h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections")
48+
}

0 commit comments

Comments
 (0)