diff --git a/README.md b/README.md index 705605e..7d1bd91 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,16 @@ Some SSH arguments have default values - for example, the default value for given Host/keyword pair exists in the config, we'll return a default for the keyword if one exists. +### Reloading SSH config files +Once the first call to `Get()`, `GetStrict()`, `GetAll()`, or `GetAllStrict()` +has been made, the contents of the config files will be cached for all future +calls to any of those functions. The `ReloadConfigs()` function will reset +this cache and replace it with the current config file contents. + +```go +ssh_config.ReloadConfigs() +``` + ### Manipulating SSH config files Here's how you can manipulate an SSH config file, and then write it back to diff --git a/config.go b/config.go index 4816e67..a26e4d3 100644 --- a/config.go +++ b/config.go @@ -8,7 +8,7 @@ // the host name to match on ("example.com"), and the second argument is the key // you want to retrieve ("Port"). The keywords are case insensitive. // -// port := ssh_config.Get("myhost", "Port") +// port := ssh_config.Get("myhost", "Port") // // You can also manipulate an SSH config file and then print it or write it back // to disk. @@ -59,7 +59,7 @@ type UserSettings struct { systemConfigFinder configFinder userConfig *Config userConfigFinder configFinder - loadConfigs sync.Once + loadConfigs *sync.Once onceErr error } @@ -167,6 +167,14 @@ func GetAllStrict(alias, key string) ([]string, error) { return DefaultUserSettings.GetAllStrict(alias, key) } +// ReloadConfigs clears the cached config data and freshly loads the config +// files again. +// +// ReloadConfigs is a wrapper around DefaultUserSettings.ReloadConfigs. +func ReloadConfigs() { + DefaultUserSettings.ReloadConfigs() +} + // Get finds the first value for key within a declaration that matches the // alias. Get returns the empty string if no value was found, or if IgnoreErrors // is false and we could not parse the configuration file. Use GetStrict to @@ -272,6 +280,9 @@ func (u *UserSettings) ConfigFinder(f func() string) { } func (u *UserSettings) doLoadConfigs() { + if u.loadConfigs == nil { + u.loadConfigs = new(sync.Once) + } u.loadConfigs.Do(func() { var filename string var err error @@ -310,6 +321,13 @@ func (u *UserSettings) doLoadConfigs() { }) } +// ReloadConfigs clears the cached config data and freshly loads the config +// files again. +func (u *UserSettings) ReloadConfigs() { + u.loadConfigs = new(sync.Once) + u.doLoadConfigs() +} + func parseFile(filename string) (*Config, error) { return parseWithDepth(filename, 0) } diff --git a/config_test.go b/config_test.go index 11b203d..b296ee3 100644 --- a/config_test.go +++ b/config_test.go @@ -259,6 +259,72 @@ func TestGetEqsign(t *testing.T) { } } +var modified1 = []byte(` +Host wap + User modified1 + KexAlgorithms diffie-hellman-group1-sha1 +`) + +var modified2 = []byte(` +Host wap + User modified2 + KexAlgorithms diffie-hellman-group1-sha1 +`) + +func TestCachedConfig(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/modified"), + } + + err1 := os.WriteFile("testdata/modified", modified1, 0644) + if err1 != nil { + t.Errorf("error writing to file: %v", err1) + } + + val1 := us.Get("wap", "User") + if val1 != "modified1" { + t.Errorf("expected to find User modified1, got %q", val1) + } + + err2 := os.WriteFile("testdata/modified", modified2, 0644) + if err1 != nil { + t.Errorf("error writing to file: %v", err2) + } + + val2 := us.Get("wap", "User") + if val2 != "modified1" { + t.Errorf("expected to find User modified1, got %q", val2) + } +} + +func TestReloadConfigs(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/modified"), + } + + err1 := os.WriteFile("testdata/modified", modified1, 0644) + if err1 != nil { + t.Errorf("error writing to file: %v", err1) + } + + val1 := us.Get("wap", "User") + if val1 != "modified1" { + t.Errorf("expected to find User modified1, got %q", val1) + } + + err2 := os.WriteFile("testdata/modified", modified2, 0644) + if err1 != nil { + t.Errorf("error writing to file: %v", err2) + } + + us.ReloadConfigs() + + val2 := us.Get("wap", "User") + if val2 != "modified2" { + t.Errorf("expected to find User modified2, got %q", val2) + } +} + var includeFile = []byte(` # This host should not exist, so we can use it for test purposes / it won't # interfere with any other configurations. diff --git a/testdata/modified b/testdata/modified new file mode 100644 index 0000000..1158323 --- /dev/null +++ b/testdata/modified @@ -0,0 +1,4 @@ + +Host wap + User modified2 + KexAlgorithms diffie-hellman-group1-sha1