Skip to content

Commit 04d998d

Browse files
committed
add tests for config.go and resolve a few config bugs
1 parent 9701629 commit 04d998d

20 files changed

+551
-15
lines changed

cmd/sshproxy/sshproxy.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func mainExitCode() int {
296296
}
297297
}
298298
} else {
299-
if config.Etcd.Mandatory {
299+
if config.Etcd.Mandatory.(bool) {
300300
log.Fatal("Etcd is mandatory but unavailable")
301301
}
302302
}

config/sshproxy.yaml

+1-4
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,7 @@
204204
# ENV1: /tmp/env
205205
# ssh:
206206
# args: ["-vvv", "-Y"]
207-
# # If routes are specified, each specified route is fully overridden, not merged.
208-
# routes:
209-
# default:
210-
# dest: [hostx]
207+
# dest: [hostx]
211208
# - match:
212209
# - groups: [bar]
213210
# groups: [baz]

pkg/utils/config.go

+71-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"regexp"
1818
"slices"
19+
"sort"
1920
"strings"
2021
"time"
2122

@@ -82,13 +83,15 @@ type sshConfig struct {
8283
Args []string `yaml:",flow,omitempty"`
8384
}
8485

86+
// We use interface{} instead of real type to check if the option was specified
87+
// or not.
8588
type etcdConfig struct {
8689
Endpoints []string `yaml:",flow"`
8790
TLS etcdTLSConfig `yaml:",omitempty"`
8891
Username string `yaml:",omitempty"`
8992
Password string `yaml:",omitempty"`
9093
KeyTTL int64 `yaml:",omitempty"`
91-
Mandatory bool `yaml:",omitempty"`
94+
Mandatory interface{} `yaml:",omitempty"`
9295
}
9396

9497
type etcdTLSConfig struct {
@@ -108,12 +111,12 @@ type subConfig struct {
108111
Dump interface{} `yaml:",omitempty"`
109112
DumpLimitSize interface{} `yaml:"dump_limit_size,omitempty"`
110113
DumpLimitWindow interface{} `yaml:"dump_limit_window,omitempty"`
111-
Etcd interface{} `yaml:",omitempty"`
114+
Etcd etcdConfig `yaml:",omitempty"`
112115
EtcdStatsInterval interface{} `yaml:"etcd_stats_interval,omitempty"`
113116
LogStatsInterval interface{} `yaml:"log_stats_interval,omitempty"`
114117
BlockingCommand interface{} `yaml:"blocking_command,omitempty"`
115118
BgCommand interface{} `yaml:"bg_command,omitempty"`
116-
SSH interface{} `yaml:",omitempty"`
119+
SSH sshConfig `yaml:",omitempty"`
117120
TranslateCommands map[string]*TranslateCommandConfig `yaml:"translate_commands,omitempty"`
118121
Environment map[string]string `yaml:",omitempty"`
119122
Service interface{} `yaml:",omitempty"`
@@ -143,8 +146,15 @@ func PrintConfig(config *Config, groups map[string]bool) []string {
143146
output = append(output, fmt.Sprintf("config.blocking_command = %s", config.BlockingCommand))
144147
output = append(output, fmt.Sprintf("config.bg_command = %s", config.BgCommand))
145148
output = append(output, fmt.Sprintf("config.ssh = %+v", config.SSH))
146-
for k, v := range config.TranslateCommands {
147-
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, v))
149+
// Internally, we don't care of TranslateCommands's order. But we want to
150+
// always display it in the same order
151+
keys := make([]string, 0, len(config.TranslateCommands))
152+
for k := range config.TranslateCommands {
153+
keys = append(keys, k)
154+
}
155+
sort.Strings(keys)
156+
for _, k := range keys {
157+
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, config.TranslateCommands[k]))
148158
}
149159
output = append(output, fmt.Sprintf("config.environment = %v", config.Environment))
150160
output = append(output, fmt.Sprintf("config.service = %s", config.Service))
@@ -168,6 +178,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
168178
}
169179

170180
if subconfig.CheckInterval != nil {
181+
if fmt.Sprintf("%T", subconfig.CheckInterval) != "string" {
182+
return fmt.Errorf("check_interval: %v is not a string", subconfig.CheckInterval)
183+
}
171184
var err error
172185
config.CheckInterval, err = time.ParseDuration(subconfig.CheckInterval.(string))
173186
if err != nil {
@@ -188,18 +201,52 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
188201
}
189202

190203
if subconfig.DumpLimitWindow != nil {
204+
if fmt.Sprintf("%T", subconfig.DumpLimitWindow) != "string" {
205+
return fmt.Errorf("dump_limit_window: %v is not a string", subconfig.DumpLimitWindow)
206+
}
191207
var err error
192208
config.DumpLimitWindow, err = time.ParseDuration(subconfig.DumpLimitWindow.(string))
193209
if err != nil {
194210
return err
195211
}
196212
}
197213

198-
if subconfig.Etcd != nil {
199-
config.Etcd = subconfig.Etcd.(etcdConfig)
214+
if subconfig.Etcd.Endpoints != nil {
215+
config.Etcd.Endpoints = subconfig.Etcd.Endpoints
216+
}
217+
218+
if subconfig.Etcd.TLS.CAFile != "" {
219+
config.Etcd.TLS.CAFile = subconfig.Etcd.TLS.CAFile
220+
}
221+
222+
if subconfig.Etcd.TLS.KeyFile != "" {
223+
config.Etcd.TLS.KeyFile = subconfig.Etcd.TLS.KeyFile
224+
}
225+
226+
if subconfig.Etcd.TLS.CertFile != "" {
227+
config.Etcd.TLS.CertFile = subconfig.Etcd.TLS.CertFile
228+
}
229+
230+
if subconfig.Etcd.Username != "" {
231+
config.Etcd.Username = subconfig.Etcd.Username
232+
}
233+
234+
if subconfig.Etcd.Password != "" {
235+
config.Etcd.Password = subconfig.Etcd.Password
236+
}
237+
238+
if subconfig.Etcd.KeyTTL != 0 {
239+
config.Etcd.KeyTTL = subconfig.Etcd.KeyTTL
240+
}
241+
242+
if subconfig.Etcd.Mandatory != nil {
243+
config.Etcd.Mandatory = subconfig.Etcd.Mandatory
200244
}
201245

202246
if subconfig.EtcdStatsInterval != nil {
247+
if fmt.Sprintf("%T", subconfig.EtcdStatsInterval) != "string" {
248+
return fmt.Errorf("etcd_stats_interval: %v is not a string", subconfig.EtcdStatsInterval)
249+
}
203250
var err error
204251
config.EtcdStatsInterval, err = time.ParseDuration(subconfig.EtcdStatsInterval.(string))
205252
if err != nil {
@@ -208,6 +255,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
208255
}
209256

210257
if subconfig.LogStatsInterval != nil {
258+
if fmt.Sprintf("%T", subconfig.LogStatsInterval) != "string" {
259+
return fmt.Errorf("log_stats_interval: %v is not a string", subconfig.LogStatsInterval)
260+
}
211261
var err error
212262
config.LogStatsInterval, err = time.ParseDuration(subconfig.LogStatsInterval.(string))
213263
if err != nil {
@@ -223,8 +273,12 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
223273
config.BgCommand = subconfig.BgCommand.(string)
224274
}
225275

226-
if subconfig.SSH != nil {
227-
config.SSH = subconfig.SSH.(sshConfig)
276+
if subconfig.SSH.Exe != "" {
277+
config.SSH.Exe = subconfig.SSH.Exe
278+
}
279+
280+
if subconfig.SSH.Args != nil {
281+
config.SSH.Args = subconfig.SSH.Args
228282
}
229283

230284
// merge translate_commands
@@ -296,7 +350,14 @@ func LoadAllDestsFromConfig(filename string) ([]string, error) {
296350
config.Dest = append(config.Dest, override.Dest...)
297351
}
298352
}
299-
return config.Dest, nil
353+
// expand destination nodesets
354+
_, nodesetDlclose, nodesetExpand := nodesets.InitExpander()
355+
defer nodesetDlclose()
356+
dsts, err := nodesetExpand(strings.Join(config.Dest, ","))
357+
if err != nil {
358+
return nil, fmt.Errorf("invalid nodeset: %s", err)
359+
}
360+
return dsts, nil
300361
}
301362

302363
// LoadConfig load configuration file and adapt it according to specified user/group/sshdHostPort.

pkg/utils/config_test.go

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// Copyright 2015-2025 CEA/DAM/DIF
2+
// Author: Arnaud Guignard <[email protected]>
3+
// Contributor: Cyril Servant <[email protected]>
4+
//
5+
// This software is governed by the CeCILL-B license under French law and
6+
// abiding by the rules of distribution of free software. You can use,
7+
// modify and/ or redistribute the software under the terms of the CeCILL-B
8+
// license as circulated by CEA, CNRS and INRIA at the following URL
9+
// "http://www.cecill.info".
10+
11+
package utils
12+
13+
import (
14+
"reflect"
15+
"sort"
16+
"testing"
17+
"time"
18+
)
19+
20+
var start time.Time = time.Now()
21+
22+
var loadConfigTests = []struct {
23+
filename, username string
24+
want []string
25+
err string
26+
}{
27+
{"nonexistingfile.yaml", "alice", []string{}, "open nonexistingfile.yaml: no such file or directory"},
28+
{"../../test/configEmpty.yaml", "alice", []string{}, "no destination defined for service 'default'"},
29+
{"../../test/configInvalid.yaml", "alice", []string{}, "yaml: found character that cannot start any token"},
30+
{"../../test/configCheckIntervalError.yaml", "alice", []string{}, `time: invalid duration "not a duration"`},
31+
{"../../test/configCheckIntervalNotString.yaml", "alice", []string{}, "check_interval: 10 is not a string"},
32+
{"../../test/configDumpLimitWindowError.yaml", "alice", []string{}, `time: invalid duration "not a duration"`},
33+
{"../../test/configDumpLimitWindowNotString.yaml", "alice", []string{}, "dump_limit_window: 10 is not a string"},
34+
{"../../test/configEtcdStatsIntervalError.yaml", "alice", []string{}, `time: invalid duration "not a duration"`},
35+
{"../../test/configEtcdStatsIntervalNotString.yaml", "alice", []string{}, "etcd_stats_interval: 10 is not a string"},
36+
{"../../test/configLogStatsIntervalError.yaml", "alice", []string{}, `time: invalid duration "not a duration"`},
37+
{"../../test/configLogStatsIntervalNotString.yaml", "alice", []string{}, "log_stats_interval: 10 is not a string"},
38+
{"../../test/configMatchSourceError.yaml", "alice", []string{}, "source: invalid address: address 127.0.0.1:abcd: invalid port"},
39+
{"../../test/configRouteSelectError.yaml", "alice", []string{}, "invalid value for `route_select` option of service 'default': notarouteselect"},
40+
{"../../test/configModeError.yaml", "alice", []string{}, "invalid value for `mode` option of service 'default': notamode"},
41+
// yes, "cannont" is an upstream typo
42+
{"../../test/configDestNodesetError.yaml", "alice", []string{}, "invalid nodeset for service 'default': cannont convert ending range to integer a - rangeset parse error"},
43+
{"../../test/configDestError.yaml", "alice", []string{}, "invalid destination '127.0.0.1:abcd' for service 'default': address 127.0.0.1:abcd: invalid port"},
44+
{"../../test/configDefault.yaml", "alice", []string{
45+
"libnodeset.so not found, falling back to iskylite's implementation",
46+
"groups = map[bar:true foo:true]",
47+
"config.debug = false",
48+
"config.log = ",
49+
"config.check_interval = 0s",
50+
"config.error_banner = ",
51+
"config.dump = ",
52+
"config.dump_limit_size = 0",
53+
"config.dump_limit_window = 0s",
54+
"config.etcd = {Endpoints:[] TLS:{CAFile: KeyFile: CertFile:} Username: Password: KeyTTL:0 Mandatory:<nil>}",
55+
"config.etcd_stats_interval = 0s",
56+
"config.log_stats_interval = 0s",
57+
"config.blocking_command = ",
58+
"config.bg_command = ",
59+
"config.ssh = {Exe:ssh Args:[-q -Y]}",
60+
"config.environment = map[]",
61+
"config.service = default",
62+
"config.dest = [127.0.0.1:22]",
63+
"config.route_select = ordered",
64+
"config.mode = sticky",
65+
"config.force_command = ",
66+
"config.command_must_match = false",
67+
"config.etcd_keyttl = 0",
68+
"config.max_connections_per_user = 0",
69+
}, ""},
70+
{"../../test/config.yaml", "alice", []string{
71+
"libnodeset.so not found, falling back to iskylite's implementation",
72+
"groups = map[bar:true foo:true]",
73+
"config.debug = true",
74+
"config.log = /tmp/sshproxy-foo/alice.log",
75+
"config.check_interval = 2m0s",
76+
"config.error_banner = an other error banner",
77+
"config.dump = /tmp/sshproxy-alice-" + start.Format(time.RFC3339Nano) + ".dump",
78+
"config.dump_limit_size = 20",
79+
"config.dump_limit_window = 3m0s",
80+
"config.etcd = {Endpoints:[host2] TLS:{CAFile:ca2.pem KeyFile:cert2.key CertFile:cert2.pem} Username:test2 Password:pass2 KeyTTL:2 Mandatory:false}",
81+
"config.etcd_stats_interval = 4m0s",
82+
"config.log_stats_interval = 5m0s",
83+
"config.blocking_command = /a/blocking/command",
84+
"config.bg_command = /a/background/command",
85+
"config.ssh = {Exe:sshhhhh Args:[-vvv -Y]}",
86+
"config.TranslateCommands.acommand = &{SSHArgs:[] Command:something DisableDump:true}",
87+
"config.TranslateCommands.internal-sftp = &{SSHArgs:[-s] Command:anothercommand DisableDump:false}",
88+
"config.environment = map[ENV1:/tmp/env XAUTHORITY:/dev/shm/.Xauthority_alice]",
89+
"config.service = service5",
90+
"config.dest = [server1:12345]",
91+
"config.route_select = bandwidth",
92+
"config.mode = balanced",
93+
"config.force_command = acommand",
94+
"config.command_must_match = false",
95+
"config.etcd_keyttl = 0",
96+
"config.max_connections_per_user = 0",
97+
}, ""},
98+
{"../../test/config.yaml", "notalice", []string{
99+
"libnodeset.so not found, falling back to iskylite's implementation",
100+
"groups = map[bar:true foo:true]",
101+
"config.debug = true",
102+
"config.log = /var/log/sshproxy/notalice.log",
103+
"config.check_interval = 2m30s",
104+
"config.error_banner = an error banner",
105+
"config.dump = /var/lib/sshproxy/dumps/notalice/" + start.Format(time.RFC3339Nano) + "-abcd.dump",
106+
"config.dump_limit_size = 10",
107+
"config.dump_limit_window = 2m31s",
108+
"config.etcd = {Endpoints:[host1:port1 host2:port2] TLS:{CAFile:ca.pem KeyFile:cert.key CertFile:cert.pem} Username:test Password:pass KeyTTL:5 Mandatory:true}",
109+
"config.etcd_stats_interval = 2m33s",
110+
"config.log_stats_interval = 2m32s",
111+
"config.blocking_command = ",
112+
"config.bg_command = ",
113+
"config.ssh = {Exe:ssh Args:[-q -Y]}",
114+
"config.TranslateCommands.internal-sftp = &{SSHArgs:[-oForwardX11=no -oForwardAgent=no -oPermitLocalCommand=no -oClearAllForwardings=yes -oProtocol=2 -s] Command:sftp DisableDump:true}",
115+
"config.environment = map[XAUTHORITY:/dev/shm/.Xauthority_notalice]",
116+
"config.service = default",
117+
"config.dest = [host5:4222]",
118+
"config.route_select = ordered",
119+
"config.mode = sticky",
120+
"config.force_command = ",
121+
"config.command_must_match = true",
122+
"config.etcd_keyttl = 3600",
123+
"config.max_connections_per_user = 50",
124+
}, ""},
125+
}
126+
127+
func TestLoadConfig(t *testing.T) {
128+
sid := "abcd"
129+
groups := map[string]bool{"foo": true, "bar": true}
130+
sshdHostPort := "127.0.0.1:22"
131+
for _, tt := range loadConfigTests {
132+
cachedConfig = Config{}
133+
config, err := LoadConfig(tt.filename, tt.username, sid, start, groups, sshdHostPort)
134+
if err == nil && tt.err != "" {
135+
t.Errorf("got no error, want %s", tt.err)
136+
} else if err != nil && err.Error() != tt.err {
137+
t.Errorf("ERROR: %s, want %s", err, tt.err)
138+
} else if err == nil && !reflect.DeepEqual(PrintConfig(config, groups), tt.want) {
139+
t.Errorf("want:\n%v\ngot:\n%v", tt.want, PrintConfig(config, groups))
140+
} else if err == nil {
141+
cachedConfig, err := LoadConfig(tt.filename, tt.username, sid, start, groups, sshdHostPort)
142+
if err != nil {
143+
t.Errorf("ERROR: %s", err)
144+
} else if config != cachedConfig {
145+
t.Error("config and cachedConfig should be the same")
146+
}
147+
}
148+
}
149+
}
150+
151+
var loadAllDestsFromConfigTests = []struct {
152+
filename string
153+
want []string
154+
err string
155+
}{
156+
{"nonexistingfile.yaml", []string{}, "open nonexistingfile.yaml: no such file or directory"},
157+
{"../../test/configInvalid.yaml", []string{}, "yaml: found character that cannot start any token"},
158+
{"../../test/configDestNodesetError.yaml", []string{}, "invalid nodeset: cannont convert ending range to integer a - rangeset parse error"},
159+
{"../../test/configDefault.yaml", []string{"127.0.0.1"}, ""},
160+
{"../../test/config.yaml", []string{
161+
"192.168.0.1",
162+
"192.168.0.2",
163+
"192.168.0.3",
164+
"192.168.0.4",
165+
"192.168.0.5",
166+
"192.168.0.6",
167+
"192.168.0.7",
168+
"192.168.0.8",
169+
"192.168.0.9",
170+
"192.168.0.10",
171+
"host5:4222",
172+
"server1:12345",
173+
}, ""},
174+
}
175+
176+
func TestLoadAllDestsFromConfig(t *testing.T) {
177+
for _, tt := range loadAllDestsFromConfigTests {
178+
config, err := LoadAllDestsFromConfig(tt.filename)
179+
if err != nil {
180+
sort.Strings(config)
181+
}
182+
if err == nil && tt.err != "" {
183+
t.Errorf("got no error, want %s", tt.err)
184+
} else if err != nil && err.Error() != tt.err {
185+
t.Errorf("ERROR: %s, want %s", err, tt.err)
186+
} else if err == nil && !reflect.DeepEqual(config, tt.want) {
187+
t.Errorf("want:\n%v\ngot:\n%v", tt.want, config)
188+
}
189+
}
190+
}

0 commit comments

Comments
 (0)