diff --git a/README.asciidoc b/README.asciidoc index 30dfcaf2..69638922 100644 --- a/README.asciidoc +++ b/README.asciidoc @@ -84,6 +84,80 @@ Modify the SSH daemon configuration +/etc/ssh/sshd_config+ by adding: ForceCommand /usr/sbin/sshproxy +Migrating to sshproxy 2 +----------------------- + +Version 2 brings a lot of changes to sshproxy: + +1. configuration file: + - **all** configuration options can now be set outside of overrides (those are + the default values) or inside an override + - `users`, `groups` and `routes` options have been replaced by the overrides + system: + * old style: + + routes: + default: + some_default_options… + service1: + source: [an.ip.sshd.listens.to] + some_sources_options… + users: + - alice,bob: + some_users_options… + groups: + - foo,bar: + some_groups_options… + + * new style: + + some_default_options… + overrides: + - match: + - sources: [an.ip.sshd.listens.to] + some_sources_options… + - match: + - users: [alice,bob] + some_users_options… + overrides: + - match: + - groups: [foo,bar] + some_groups_options… + + - the `match` conditions of the overrides system can be combined. Here is an + example meaning "match if (the user is in the group foo **and** in the + group bar) **or** ((the user is alice **or** bob) **and** the user is + connected to an.ip.sshd.listens.to)": + + overrides: + - match: + - groups: [foo] + groups: [bar] + - users: [alice,bob] + sources: [an.ip.sshd.listens.to] + + - nodesets can now be used for the `dest` key + - if `libnodeset.so` (from https://github.com/fdiakh/nodeset-rs) is found, it + allows the use of clustershell groups where nodesets are allowed + - new option: `blocking_command` runs a command before starting the ssh + connection to the destination. If the command does not return 0, the + connection is aborted +2. command line interface: + - in all the tables, `Host` and `Port` columns are now merged into a single + `Host:Port` + - `sshproxyctl get_config` has been removed and replaced by + `sshproxyctl show config` + - `sshproxyctl show hosts` and `sshproxyctl show users -all` now display + persist info + - `sshproxyctl enable HOST [PORT]` has been removed and replaced by + `sshproxyctl enable -all|-host HOST [-port PORT]` + - `sshproxyctl disable HOST [PORT]` has been removed and replaced by + `sshproxyctl disable -all|-host HOST [-port PORT]` + - `sshproxyctl forget HOST [PORT]` has been removed and replaced by + `sshproxyctl forget host -all|-host HOST [-port PORT]` + - `sshproxyctl error_banner` (without any parameter) has been removed and + replaced by `sshproxyctl forget error_banner` + Copying ------- diff --git a/cmd/sshproxyctl/sshproxyctl.go b/cmd/sshproxyctl/sshproxyctl.go index 3515def7..2827aa76 100644 --- a/cmd/sshproxyctl/sshproxyctl.go +++ b/cmd/sshproxyctl/sshproxyctl.go @@ -35,7 +35,6 @@ var ( // SshproxyVersion is set by Makefile SshproxyVersion = "0.0.0+noproperlybuilt" defaultConfig = "/etc/sshproxy/sshproxy.yaml" - defaultHostPort = "22" ) func mustInitEtcdClient(configFile string) *utils.Client { @@ -536,6 +535,13 @@ func setErrorBanner(errorBanner string, expire time.Time, configFile string) err return cli.SetErrorBanner(errorBanner, expire) } +func delErrorBanner(configFile string) error { + cli := mustInitEtcdClient(configFile) + defer cli.Close() + + return cli.DelErrorBanner() +} + func showErrorBanner(configFile string) { cli := mustInitEtcdClient(configFile) defer cli.Close() @@ -592,7 +598,7 @@ The commands are: version show version number and exit show show states present in etcd enable enable a host in etcd - forget forget a host in etcd + forget forget a host/error_banner in etcd disable disable a host in etcd error_banner set the error banner in etcd @@ -654,39 +660,53 @@ The options are: return fs } -func newEnableParser() *flag.FlagSet { +func newEnableParser(allFlag *bool, hostString *string, portString *string) *flag.FlagSet { fs := flag.NewFlagSet("enable", flag.ExitOnError) + fs.BoolVar(allFlag, "all", false, "enable all hosts present in config") + fs.StringVar(hostString, "host", "", "hostname to enable (can be a nodeset)") + fs.StringVar(portString, "port", "", "port to enable (can be a nodeset)") fs.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s enable HOST [PORT] + fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s enable -all|-host HOST [-port PORT] -Enable a previously disabled host in etcd. The default port is %s. Host and port -can be nodesets. -`, os.Args[0], defaultHostPort) +Enable a previously disabled host in etcd. +`, os.Args[0]) + fs.PrintDefaults() os.Exit(2) } return fs } -func newForgetParser() *flag.FlagSet { +func newForgetParser(allFlag *bool, hostString *string, portString *string) *flag.FlagSet { fs := flag.NewFlagSet("forget", flag.ExitOnError) + fs.BoolVar(allFlag, "all", false, "forget all hosts present in config") + fs.StringVar(hostString, "host", "", "hostname to forget (can be a nodeset)") + fs.StringVar(portString, "port", "", "port to forget (can be a nodeset)") fs.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s forget HOST [PORT] + fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s forget COMMAND [OPTIONS] -Forget a host in etcd. The default port is %s. Remember that if this host is -used, it will appear back in the list. Host and port can be nodesets. -`, os.Args[0], defaultHostPort) +The commands are: + host -all|-host HOST [-port PORT] forget a host in etcd + error_banner forget the error_banner in etcd + +The options are: +`, os.Args[0]) + fs.PrintDefaults() os.Exit(2) } return fs } -func newDisableParser() *flag.FlagSet { +func newDisableParser(allFlag *bool, hostString *string, portString *string) *flag.FlagSet { fs := flag.NewFlagSet("disable", flag.ExitOnError) + fs.BoolVar(allFlag, "all", false, "disable all hosts present in config") + fs.StringVar(hostString, "host", "", "hostname to disable (can be a nodeset)") + fs.StringVar(portString, "port", "", "port to disable (can be a nodeset)") fs.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s disable HOST [PORT] + fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s disable -all|-host HOST [-port PORT] -Disable a host in etcd. The default port is %s. Host and port can be nodesets. -`, os.Args[0], defaultHostPort) +Disable a host in etcd. +`, os.Args[0]) + fs.PrintDefaults() os.Exit(2) } return fs @@ -708,53 +728,66 @@ The options are: return fs } -func getHostPortFromCommandLine(args []string) ([]string, []string, error) { +func getHostPortFromCommandLine(allFlag bool, hostsNodeset string, portsNodeset string, configFile string) ([]string, error) { _, nodesetDlclose, nodesetExpand := nodesets.InitExpander() defer nodesetDlclose() - hostsNodeset, portsNodeset := "", defaultHostPort - switch len(args) { - case 2: - hostsNodeset, portsNodeset = args[0], args[1] - case 1: - hostsNodeset = args[0] - default: - return []string{}, []string{}, fmt.Errorf("wrong number of arguments") - } - hosts, err := nodesetExpand(hostsNodeset) + configDests, err := utils.LoadAllDestsFromConfig(configFile) if err != nil { - return []string{}, []string{}, fmt.Errorf("%s", err) + return []string{}, fmt.Errorf("%s", err) } - ports, err := nodesetExpand(portsNodeset) - if err != nil { - return []string{}, []string{}, fmt.Errorf("%s", err) + + if allFlag && portsNodeset == "" { + return configDests, nil + } + + var hosts []string + var ports []string + for _, configDest := range configDests { + host, port, err := utils.SplitHostPort(configDest) + if err != nil { + return []string{}, fmt.Errorf("%s", err) + } + hosts = append(hosts, host) + ports = append(ports, port) + } + + if !allFlag { + hosts, err = nodesetExpand(hostsNodeset) + if err != nil { + return []string{}, fmt.Errorf("%s", err) + } } + + if portsNodeset != "" { + ports, err = nodesetExpand(portsNodeset) + if err != nil { + return []string{}, fmt.Errorf("%s", err) + } + } + + var hostPorts []string for _, port := range ports { if iport, err := strconv.Atoi(port); err != nil { - return []string{}, []string{}, fmt.Errorf("port \"%s\" must be an integer", port) + return []string{}, fmt.Errorf("port \"%s\" must be an integer", port) } else if iport < 0 || iport > 65535 { - return []string{}, []string{}, fmt.Errorf("port \"%s\" must be in the 0-65535 range", port) + return []string{}, fmt.Errorf("port \"%s\" must be in the 0-65535 range", port) } for _, host := range hosts { if _, _, err := net.SplitHostPort(host + ":" + port); err != nil { - return []string{}, []string{}, fmt.Errorf("%s", err) + return []string{}, fmt.Errorf("%s", err) } + hostPorts = append(hostPorts, host+":"+port) } } - return hosts, ports, nil + return hostPorts, nil } func getErrorBannerFromCommandLine(args []string) (string, error) { - errorBanner := "" - switch len(args) { - case 0: - errorBanner = "" - case 1: - errorBanner = args[0] - default: - return "", fmt.Errorf("wrong number of arguments") + if len(args) == 1 { + return args[0], nil } - return errorBanner, nil + return "", fmt.Errorf("wrong number of arguments") } func byteToHuman(b int, passthrough bool) string { @@ -831,14 +864,16 @@ func main() { var userString string var groupsString string var sourceString string + var hostString string + var portString string parsers := map[string]*flag.FlagSet{ "help": newHelpParser(), "version": newVersionParser(), "show": newShowParser(&csvFlag, &jsonFlag, &allFlag, &userString, &groupsString, &sourceString), - "enable": newEnableParser(), - "forget": newForgetParser(), - "disable": newDisableParser(), + "enable": newEnableParser(&allFlag, &hostString, &portString), + "forget": newForgetParser(&allFlag, &hostString, &portString), + "disable": newDisableParser(&allFlag, &hostString, &portString), "error_banner": newErrorBannerParser(&expire), } @@ -866,7 +901,7 @@ func main() { p := parsers[cmd] p.Parse(args) if p.NArg() == 0 { - fmt.Fprintf(os.Stderr, "ERROR: missing 'hosts' or 'connections'\n\n") + fmt.Fprintf(os.Stderr, "ERROR: missing 'hosts', 'connections', 'users', 'groups', 'error_banner' or 'config'\n\n") p.Usage() } subcmd := p.Arg(0) @@ -893,41 +928,75 @@ func main() { case "enable": p := parsers[cmd] p.Parse(args) - hosts, ports, err := getHostPortFromCommandLine(p.Args()) + if !allFlag && hostString == "" { + fmt.Fprintf(os.Stderr, "ERROR: missing '-all' or '-host'\n\n") + p.Usage() + } + hostPorts, err := getHostPortFromCommandLine(allFlag, hostString, portString, *configFile) if err != nil { fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) p.Usage() } - for _, host := range hosts { - for _, port := range ports { - enableHost(host, port, *configFile) + for _, hostPort := range hostPorts { + host, port, err := utils.SplitHostPort(hostPort) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) + p.Usage() } + enableHost(host, port, *configFile) } case "forget": p := parsers[cmd] p.Parse(args) - hosts, ports, err := getHostPortFromCommandLine(p.Args()) - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) + if p.NArg() == 0 { + fmt.Fprintf(os.Stderr, "ERROR: missing 'host' or 'error_banner'\n\n") p.Usage() } - for _, host := range hosts { - for _, port := range ports { + subcmd := p.Arg(0) + // parse flags after subcommand + args = p.Args()[1:] + p.Parse(args) + switch subcmd { + case "host": + if !allFlag && hostString == "" { + fmt.Fprintf(os.Stderr, "ERROR: missing '-all' or '-host'\n\n") + p.Usage() + } + hostPorts, err := getHostPortFromCommandLine(allFlag, hostString, portString, *configFile) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) + p.Usage() + } + for _, hostPort := range hostPorts { + host, port, err := utils.SplitHostPort(hostPort) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) + p.Usage() + } forgetHost(host, port, *configFile) } + case "error_banner": + delErrorBanner(*configFile) } case "disable": p := parsers[cmd] p.Parse(args) - hosts, ports, err := getHostPortFromCommandLine(p.Args()) + if !allFlag && hostString == "" { + fmt.Fprintf(os.Stderr, "ERROR: missing '-all' or '-host'\n\n") + p.Usage() + } + hostPorts, err := getHostPortFromCommandLine(allFlag, hostString, portString, *configFile) if err != nil { fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) p.Usage() } - for _, host := range hosts { - for _, port := range ports { - disableHost(host, port, *configFile) + for _, hostPort := range hostPorts { + host, port, err := utils.SplitHostPort(hostPort) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", err) + p.Usage() } + disableHost(host, port, *configFile) } case "error_banner": p := parsers[cmd] diff --git a/doc/sshproxyctl.txt b/doc/sshproxyctl.txt index 60d8cc71..0aa31ab8 100644 --- a/doc/sshproxyctl.txt +++ b/doc/sshproxyctl.txt @@ -39,29 +39,40 @@ COMMANDS *version*:: Show version number and exit. -*enable HOST [PORT]*:: - Enable a destination host in etcd if the host was previously disabled by - the 'disable' command (see below). The port by default is 22 if not - specified. Host and port can be nodesets. If libnodeset.so is - available, clustershell groups can also be used. - -*disable HOST [PORT]*:: +*enable -all|-host HOST [-port PORT]*:: + Enable a destination host in etcd if the host was previously disabled + by the 'disable' command (see below). If '-all' is specified instead + of '-host', all the hosts present in the configuration are enabled. If + no '-port' is specified, all the ports present in the configuration are + used. 'HOST' and 'PORT' can be nodesets. If libnodeset.so (from + https://github.com/fdiakh/nodeset-rs) is available, clustershell groups + can also be used. + +*disable -all|-host HOST [-port PORT]*:: Disable a destination host in etcd. A disabled host will not be proposed as a destination. The only way to enable it again is to send - the 'enable' command. It could be used for host maintenance. The port - by default is 22 if not specified. Host and port can be nodesets. If - libnodeset.so is available, clustershell groups can also be used. - -*forget HOST [PORT]*:: + the 'enable' command. It could be used for host maintenance. If + '-all' is specified instead of '-host', all the hosts present in the + configuration are enabled. If no '-port' is specified, all the ports + present in the configuration are used. 'HOST' and 'PORT' can be + nodesets. If libnodeset.so (from https://github.com/fdiakh/nodeset-rs) + is available, clustershell groups can also be used. + +*forget host -all|-host HOST [-port PORT]*:: Forget a host in etcd. Remember that if this host is used, it will - appear back in the list. The port by default is 22 if not specified. - Host and port can be nodesets. If libnodeset.so is available, - clustershell groups can also be used. + appear back in the list. If '-all' is specified instead of '-host', + all the hosts present in the configuration are forgotten. If no '-port' + is specified, all the ports present in the configuration are used. + 'HOST' and 'PORT' can be nodesets. If libnodeset.so (from + https://github.com/fdiakh/nodeset-rs) is available, clustershell groups + can also be used. + +*forget error_banner*:: + Remove the error banner in etcd. *error_banner [-expire EXPIRATION] MESSAGE*:: - Set the error banner in etcd. Removes the error banner in etcd if - 'MESSAGE' is absent. 'MESSAGE' can be multiline. The error banner is - displayed to the client when no backend can be reached (more + Set the error banner in etcd. 'MESSAGE' can be multiline. The error + banner is displayed to the client when no backend can be reached (more precisely, when all backends are either down or disabled in etcd). '-expire' sets the expiration date of this error banner. Format: 'YYYY-MM-DD[ HH:MM[:SS]]' diff --git a/misc/sshproxyctl-completion.bash b/misc/sshproxyctl-completion.bash index be7aeb48..49772ca9 100644 --- a/misc/sshproxyctl-completion.bash +++ b/misc/sshproxyctl-completion.bash @@ -9,6 +9,18 @@ _sshproxyctl() { opts="-h -c ${commands}" case "${prev}" in + disable) + COMPREPLY=( $(compgen -W '-all -host -port' -- "${cur}") ) + ;; + enable) + COMPREPLY=( $(compgen -W '-all -host -port' -- "${cur}") ) + ;; + error_banner) + COMPREPLY=( $(compgen -W '-expire' -- "${cur}") ) + ;; + forget) + COMPREPLY=( $(compgen -W '-all -host -port host error_banner' -- "${cur}") ) + ;; help) COMPREPLY=( $(compgen -W "${commands}" -- "${cur}") ) ;; @@ -18,6 +30,9 @@ _sshproxyctl() { connections) COMPREPLY=( $(compgen -W '-all -csv -json' -- "${cur}") ) ;; + host) + COMPREPLY=( $(compgen -W '-all -host -port' -- "${cur}") ) + ;; hosts) COMPREPLY=( $(compgen -W '-csv -json' -- "${cur}") ) ;; @@ -30,27 +45,30 @@ _sshproxyctl() { config) COMPREPLY=( $(compgen -W '-user -groups -source' -- "${cur}") ) ;; - error_banner) - COMPREPLY=( $(compgen -W '-expire' -- "${cur}") ) - ;; -all) - COMPREPLY=( $(compgen -W '-csv -json connections users groups' -- "${cur}") ) + COMPREPLY=( $(compgen -W '-csv -json -port connections users groups' -- "${cur}") ) ;; -csv) COMPREPLY=( $(compgen -W '-all connections hosts users groups' -- "${cur}") ) ;; + -groups) + COMPREPLY=( $(compgen -W '-user -source config' -- "${cur}") ) + ;; + -host) + COMPREPLY=( $(compgen -W '-port' -- "${cur}") ) + ;; -json) COMPREPLY=( $(compgen -W '-all connections hosts users groups' -- "${cur}") ) ;; - -user) - COMPREPLY=( $(compgen -W '-groups -source config' -- "${cur}") ) - ;; - -groups) - COMPREPLY=( $(compgen -W '-user -source config' -- "${cur}") ) + -port) + COMPREPLY=( $(compgen -W '-all -host' -- "${cur}") ) ;; -source) COMPREPLY=( $(compgen -W '-user -groups config' -- "${cur}") ) ;; + -user) + COMPREPLY=( $(compgen -W '-groups -source config' -- "${cur}") ) + ;; -c) _filedir ;; diff --git a/pkg/utils/config.go b/pkg/utils/config.go index 53b9aa01..fd689e9b 100644 --- a/pkg/utils/config.go +++ b/pkg/utils/config.go @@ -281,6 +281,24 @@ func replace(src string, replacer *patternReplacer) string { return replacer.Regexp.ReplaceAllString(src, replacer.Text) } +// LoadAllDestsFromConfig loads configuration file and returns all defined destinations. +func LoadAllDestsFromConfig(filename string) ([]string, error) { + yamlFile, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + var config Config + if err := yaml.Unmarshal(yamlFile, &config); err != nil { + return nil, err + } + for _, override := range config.Overrides { + if override.Dest != nil { + config.Dest = append(config.Dest, override.Dest...) + } + } + return config.Dest, nil +} + // LoadConfig load configuration file and adapt it according to specified user/group/sshdHostPort. func LoadConfig(filename, currentUsername, sid string, start time.Time, groups map[string]bool, sshdHostPort string) (*Config, error) { if cachedConfig.ready { diff --git a/test/fedora-image/sshproxy_test.go b/test/fedora-image/sshproxy_test.go index 73d7db57..206fcb4d 100644 --- a/test/fedora-image/sshproxy_test.go +++ b/test/fedora-image/sshproxy_test.go @@ -156,8 +156,7 @@ type aggConnection struct { } func getEtcdConnections() ([]aggConnection, string) { - ctx := context.Background() - _, stdout, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s show -json connections", SSHPROXYCTL)}, nil, nil) + _, stdout, _, err := etcdCommand("show -json connections") if err != nil { log.Fatal(err) } @@ -179,8 +178,7 @@ type host struct { } func getEtcdHosts() ([]host, string) { - ctx := context.Background() - _, stdout, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s show -json hosts", SSHPROXYCTL)}, nil, nil) + _, stdout, _, err := etcdCommand("show -json hosts") if err != nil { log.Fatal(err) } @@ -195,27 +193,82 @@ func getEtcdHosts() ([]host, string) { } func disableHost(host string) { - ctx := context.Background() - _, _, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s disable %s", SSHPROXYCTL, host)}, nil, nil) + _, _, _, err := etcdCommand(fmt.Sprintf("disable %s", host)) if err != nil { log.Fatal(err) } } func enableHost(host string) { - ctx := context.Background() - _, _, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s enable %s", SSHPROXYCTL, host)}, nil, nil) + _, _, _, err := etcdCommand(fmt.Sprintf("enable %s", host)) if err != nil { log.Fatal(err) } } func forgetHost(host string) error { - ctx := context.Background() - _, _, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s forget %s", SSHPROXYCTL, host)}, nil, nil) + _, _, _, err := etcdCommand(fmt.Sprintf("forget host %s", host)) return err } +type user struct { + User string + Group string + Service string + N int +} + +func getEtcdUsers(mode string, allFlag bool) (map[string]user, string) { + all := "" + if allFlag { + all = " -all" + } + _, stdout, _, err := etcdCommand(fmt.Sprintf("show -json %s%s", mode, all)) + if err != nil { + log.Fatal(err) + } + + jsonStr := strings.TrimSpace(string(stdout)) + var users []user + if err := json.Unmarshal(stdout, &users); err != nil { + log.Fatal(err) + } + usersMap := map[string]user{} + for _, user := range users { + key := "" + if user.User != "" { + key = user.User + } else { + key = user.Group + } + if allFlag { + key += "@" + user.Service + } + usersMap[key] = user + } + + return usersMap, jsonStr +} + +func setEtcdErrorBanner(banner string) { + _, _, _, err := etcdCommand(fmt.Sprintf("error_banner '%s'", banner)) + if err != nil { + log.Fatal(err) + } +} + +func forgetEtcdErrorBanner() { + _, _, _, err := etcdCommand("forget error_banner") + if err != nil { + log.Fatal(err) + } +} + +func etcdCommand(command string) (int, []byte, []byte, error) { + ctx := context.Background() + return runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s %s", SSHPROXYCTL, command)}, nil, nil) +} + var simpleConnectTests = []struct { user string port int @@ -269,34 +322,34 @@ func TestBlockingCommand(t *testing.T) { } func TestNodesets(t *testing.T) { - disableHost("server[1000-1002]") + disableHost("-host server[1000-1002]") checkHostState(t, "server1000:22", "disabled", true) checkHostState(t, "server1001:22", "disabled", true) checkHostState(t, "server1002:22", "disabled", true) - enableHost("server[1000-1002]") + enableHost("-host server[1000-1002]") checkHostState(t, "server1000:22", "up", true) checkHostState(t, "server1001:22", "up", true) checkHostState(t, "server1002:22", "up", true) - err := forgetHost("server[1001]") + err := forgetHost("-host server[1001]") if err != nil { t.Errorf("got %s, expected no error", err) } checkHostState(t, "server1000:22", "up", true) checkHostState(t, "server1001:22", "", false) checkHostState(t, "server1002:22", "up", true) - err = forgetHost("server[1000-1002]") + err = forgetHost("-host server[1000-1002]") if err != nil { t.Errorf("got %s, expected no error", err) } checkHostState(t, "server1000:22", "", false) checkHostState(t, "server1001:22", "", false) checkHostState(t, "server1002:22", "", false) - err = forgetHost("server[12345]") + err = forgetHost("-host server[12345]") if err != nil { t.Errorf("got %s, expected no error", err) } checkHostState(t, "server12345:22", "", false) - if forgetHost("server[notAnumber]") == nil { + if forgetHost("-host server[notAnumber]") == nil { t.Errorf("got no error, expected error due to notAnumber not being a number") } } @@ -460,7 +513,7 @@ func TestStickyConnections(t *testing.T) { // remove old connections stored in etcd time.Sleep(4 * time.Second) - disableHost("server1") + disableHost("-host server1") checkHostState(t, "server1:22", "disabled", true) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -473,7 +526,7 @@ func TestStickyConnections(t *testing.T) { process1 := <-ch time.Sleep(time.Second) - enableHost("server1") + enableHost("-host server1") checkHostState(t, "server1:22", "up", true) args, cmdStr := prepareCommand("gateway2", 2022, "hostname") @@ -492,7 +545,7 @@ func TestNotLongStickyConnections(t *testing.T) { // remove old connections stored in etcd time.Sleep(4 * time.Second) - disableHost("server1") + disableHost("-host server1") checkHostState(t, "server1:22", "disabled", true) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -504,7 +557,7 @@ func TestNotLongStickyConnections(t *testing.T) { } time.Sleep(2 * time.Second) - enableHost("server1") + enableHost("-host server1") checkHostState(t, "server1:22", "up", true) args, cmdStr := prepareCommand("gateway2", 2022, "hostname") @@ -523,7 +576,7 @@ func TestLongStickyConnections(t *testing.T) { time.Sleep(4 * time.Second) updateLineSSHProxyConf("etcd_keyttl", "10") - disableHost("server1") + disableHost("-host server1") checkHostState(t, "server1:22", "disabled", true) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -535,7 +588,7 @@ func TestLongStickyConnections(t *testing.T) { } time.Sleep(2 * time.Second) - enableHost("server1") + enableHost("-host server1") checkHostState(t, "server1:22", "up", true) args, cmdStr := prepareCommand("gateway2", 2022, "hostname") @@ -662,7 +715,7 @@ func TestEnableDisableHost(t *testing.T) { t.Errorf("%s got %s, expected server1", cmdStr, dest) } - disableHost("server[1,100]") + disableHost("-host server[1,100]") checkHostState(t, "server1:22", "disabled", true) _, stdout, _, err = runCommand(ctx, "ssh", args, nil, nil) @@ -674,7 +727,7 @@ func TestEnableDisableHost(t *testing.T) { t.Errorf("%s got %s, expected server2", cmdStr, dest) } - enableHost("server1") + enableHost("-host server1") checkHostState(t, "server1:22", "up", true) // test stickiness @@ -699,46 +752,6 @@ func TestEnableDisableHost(t *testing.T) { } } -type user struct { - User string - Group string - Service string - N int -} - -func getEtcdUsers(mode string, allFlag bool) (map[string]user, string) { - all := "" - if allFlag { - all = " -all" - } - ctx := context.Background() - _, stdout, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s show -json %s%s", SSHPROXYCTL, mode, all)}, nil, nil) - if err != nil { - log.Fatal(err) - } - - jsonStr := strings.TrimSpace(string(stdout)) - var users []user - if err := json.Unmarshal(stdout, &users); err != nil { - log.Fatal(err) - } - usersMap := map[string]user{} - for _, user := range users { - key := "" - if user.User != "" { - key = user.User - } else { - key = user.Group - } - if allFlag { - key += "@" + user.Service - } - usersMap[key] = user - } - - return usersMap, jsonStr -} - func TestEtcdUsers(t *testing.T) { // remove old connections stored in etcd time.Sleep(4 * time.Second) @@ -965,6 +978,40 @@ func TestSCP(t *testing.T) { } } +func compareSshToErrorBanner(errorBanner string) string { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + args, _ := prepareCommand("gateway1", 2023, "hostname") + _, stdout, _, err := runCommand(ctx, "ssh", args, nil, nil) + stdoutStr := strings.TrimSpace(string(stdout)) + if err == nil { + return "Expected error because all hosts are disabled, got no error" + } else if stdoutStr != errorBanner { + return fmt.Sprintf("got error = %s, want %s", stdoutStr, errorBanner) + } + return "" +} + +func TestErrorBanner(t *testing.T) { + disableHost("-all") + defer enableHost("-all") + defaultError := "a default error" + line := fmt.Sprintf("error_banner: %s", defaultError) + addLineSSHProxyConf(line) + defer removeLineSSHProxyConf(line) + + customError := "a custom error" + setEtcdErrorBanner(customError) + if errStr := compareSshToErrorBanner(customError); errStr != "" { + t.Error(errStr) + } + + forgetEtcdErrorBanner() + if errStr := compareSshToErrorBanner(defaultError); errStr != "" { + t.Error(errStr) + } +} + func waitForServers(hostports []string, timeout time.Duration) { results := make([]bool, len(hostports)) ticker := time.NewTicker(time.Second)