diff --git a/cli/cli.go b/cli/cli.go index 40022e3..297eef6 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -196,3 +196,45 @@ func ParseRRTypes(t []string) (map[uint16]bool, error) { } return rrTypes, nil } + +// isBool checks if a flag by a given name is a boolean flag of Flags +func isBool(name string) bool { + v := reflect.ValueOf(Flags{}) + vT := v.Type() + for i := 0; i < v.NumField(); i++ { + if vT.Field(i).Tag.Get("short") == name || vT.Field(i).Tag.Get("long") == name { + return vT.Field(i).Type == reflect.TypeOf(true) + } + } + return false +} + +// AddEqualSigns adds equal signs between flags and their values, ignoring boolean flags +func AddEqualSigns(args []string) []string { + var newArgs []string + skip := false + for i, arg := range args { + if skip { + skip = false + continue + } + + isFlag := arg[0] == '-' + flagName := strings.TrimLeft(arg, "-") + + if isFlag && isBool(flagName) { // Standalone boolean flag + newArgs = append(newArgs, arg) + } else if isFlag && !isBool(flagName) { // Flag with mapping + if i+1 < len(args) && args[i+1][0] != '-' { // If the next argument is not a flag + newArgs = append(newArgs, arg+"="+args[i+1]) + skip = true + } else { // If the next argument is a flag, add the flag as is + newArgs = append(newArgs, arg) + } + } else { // Positional argument + newArgs = append(newArgs, arg) + } + } + + return newArgs +} diff --git a/main.go b/main.go index 06fc074..07ece2b 100644 --- a/main.go +++ b/main.go @@ -207,6 +207,7 @@ func parseServer(s string) (string, transport.Type, error) { // driver is the "main" function for this program that accepts a flag slice for testing func driver(args []string, out io.Writer) error { args = cli.SetFalseBooleans(&opts, args) + args = cli.AddEqualSigns(args) parser := flags.NewParser(&opts, flags.Default) parser.Usage = `[OPTIONS] [@server] [type...] [name] diff --git a/main_test.go b/main_test.go index a7189d1..c775a53 100644 --- a/main_test.go +++ b/main_test.go @@ -533,3 +533,13 @@ func TestMainQueryDomainWithRRType(t *testing.T) { assert.Nil(t, err) assert.Regexp(t, regexp.MustCompile(`NS.network. .* A .*`), out.String()) } + +func TestMainQueryTypeFlag(t *testing.T) { + out, err := run( + "-t", "65", + "cloudflare.com", + "-v", + ) + assert.Nil(t, err) + assert.Regexp(t, regexp.MustCompile(`cloudflare.com. .* HTTPS 1 .*`), out.String()) +}