Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 17a0bb3

Browse files
committedMar 11, 2025·
add "sshproxyctl forget persist"
1 parent a49296b commit 17a0bb3

File tree

5 files changed

+153
-23
lines changed

5 files changed

+153
-23
lines changed
 

‎README.asciidoc

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ Version 2 brings a lot of changes to sshproxy:
157157
`sshproxyctl forget host -all|-host HOST [-port PORT]`
158158
- `sshproxyctl error_banner` (without any parameter) has been removed and
159159
replaced by `sshproxyctl forget error_banner`
160+
- `sshproxyctl forget persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT]`
161+
has been added
160162

161163
Copying
162164
-------

‎cmd/sshproxyctl/sshproxyctl.go

+36-6
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,24 @@ func disableHost(host, port, configFile string) error {
525525
return cli.SetHost(key, utils.Disabled, time.Now())
526526
}
527527

528+
func forgetPersist(user, service, host, port, configFile string) error {
529+
cli := mustInitEtcdClient(configFile)
530+
defer cli.Close()
531+
532+
history, err := cli.GetHistory(user, service, host, port)
533+
if err != nil {
534+
return err
535+
}
536+
537+
for _, kv := range history {
538+
err := cli.DelHistory(kv.User)
539+
if err != nil {
540+
return err
541+
}
542+
}
543+
return nil
544+
}
545+
528546
func setErrorBanner(errorBanner string, expire time.Time, configFile string) error {
529547
cli := mustInitEtcdClient(configFile)
530548
defer cli.Close()
@@ -598,7 +616,7 @@ The commands are:
598616
version show version number and exit
599617
show show states present in etcd
600618
enable enable a host in etcd
601-
forget forget a host/error_banner in etcd
619+
forget forget a host/error_banner/persist in etcd
602620
disable disable a host in etcd
603621
error_banner set the error banner in etcd
604622
@@ -676,17 +694,22 @@ Enable a previously disabled host in etcd.
676694
return fs
677695
}
678696

679-
func newForgetParser(allFlag *bool, hostString *string, portString *string) *flag.FlagSet {
697+
func newForgetParser(allFlag *bool, hostString, portString, userString, serviceString *string) *flag.FlagSet {
680698
fs := flag.NewFlagSet("forget", flag.ExitOnError)
681699
fs.BoolVar(allFlag, "all", false, "forget all hosts present in config")
682700
fs.StringVar(hostString, "host", "", "hostname to forget (can be a nodeset)")
683701
fs.StringVar(portString, "port", "", "port to forget (can be a nodeset)")
702+
fs.StringVar(userString, "user", "", "forget all persistent connections of this user")
703+
fs.StringVar(serviceString, "service", "", "forget all persistent connections of this service")
684704
fs.Usage = func() {
685705
fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s forget COMMAND [OPTIONS]
686706
687-
The commands are:
688-
host -all|-host HOST [-port PORT] forget a host in etcd
689-
error_banner forget the error_banner in etcd
707+
The cammands are:
708+
host -all|-host HOST [-port PORT] forget a host in etcd
709+
error_banner forget the error_banner in etcd
710+
persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT] forget a persistent connection in etcd
711+
(needs at least one option)
712+
(only connections matching all the options are forgotten)
690713
691714
The options are:
692715
`, os.Args[0])
@@ -866,13 +889,14 @@ func main() {
866889
var sourceString string
867890
var hostString string
868891
var portString string
892+
var serviceString string
869893

870894
parsers := map[string]*flag.FlagSet{
871895
"help": newHelpParser(),
872896
"version": newVersionParser(),
873897
"show": newShowParser(&csvFlag, &jsonFlag, &allFlag, &userString, &groupsString, &sourceString),
874898
"enable": newEnableParser(&allFlag, &hostString, &portString),
875-
"forget": newForgetParser(&allFlag, &hostString, &portString),
899+
"forget": newForgetParser(&allFlag, &hostString, &portString, &userString, &serviceString),
876900
"disable": newDisableParser(&allFlag, &hostString, &portString),
877901
"error_banner": newErrorBannerParser(&expire),
878902
}
@@ -977,6 +1001,12 @@ func main() {
9771001
}
9781002
case "error_banner":
9791003
delErrorBanner(*configFile)
1004+
case "persist":
1005+
if userString == "" && serviceString == "" && hostString == "" && portString == "" {
1006+
fmt.Fprintf(os.Stderr, "ERROR: missing '-user', '-service', '-host' or '-port'\n\n")
1007+
p.Usage()
1008+
}
1009+
forgetPersist(userString, serviceString, hostString, portString, *configFile)
9801010
}
9811011
case "disable":
9821012
p := parsers[cmd]

‎doc/sshproxyctl.txt

+4
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ COMMANDS
7070
*forget error_banner*::
7171
Remove the error banner in etcd.
7272

73+
*forget persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT]*::
74+
Forget a persistent connection in etcd. Needs at least one option.
75+
Only connections matching all the options are forgotten.
76+
7377
*error_banner [-expire EXPIRATION] MESSAGE*::
7478
Set the error banner in etcd. 'MESSAGE' can be multiline. The error
7579
banner is displayed to the client when no backend can be reached (more

‎pkg/utils/etcd.go

+52-17
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ import (
2020
"fmt"
2121
"os"
2222
"regexp"
23+
"slices"
2324
"sort"
2425
"strconv"
2526
"strings"
2627
"time"
2728

29+
"github.com/cea-hpc/sshproxy/pkg/nodesets"
30+
2831
"github.com/op/go-logging"
2932
"go.etcd.io/etcd/client/v3"
3033
"go.uber.org/zap"
@@ -399,6 +402,18 @@ func (c *Client) DelHost(hostport string) error {
399402
return nil
400403
}
401404

405+
// DelHistory deletes a history key (passed as "user@service") in etcd.
406+
func (c *Client) DelHistory(history string) error {
407+
key := toHistoryKey(history)
408+
ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout)
409+
_, err := c.cli.Delete(ctx, key, clientv3.WithPrefix())
410+
cancel()
411+
if err != nil {
412+
return err
413+
}
414+
return nil
415+
}
416+
402417
// SetHost sets a host (passed as "host:port") state and last checked time (ts)
403418
// in etcd.
404419
func (c *Client) SetHost(hostport string, state State, ts time.Time) error {
@@ -666,7 +681,7 @@ func (c *Client) GetAllHosts() ([]*FlatHost, error) {
666681
}
667682
}
668683

669-
history, err := c.GetAllHistory()
684+
history, err := c.GetHistory("", "", "", "")
670685
if err != nil {
671686
return nil, fmt.Errorf("ERROR: getting history from etcd: %v", err)
672687
}
@@ -748,7 +763,7 @@ func (c *Client) GetAllUsers(allFlag bool) ([]*FlatUser, error) {
748763
}
749764

750765
if allFlag {
751-
history, err := c.GetAllHistory()
766+
history, err := c.GetHistory("", "", "", "")
752767
if err != nil {
753768
return nil, fmt.Errorf("ERROR: getting history from etcd: %v", err)
754769
}
@@ -863,36 +878,56 @@ type FlatHistory struct {
863878
TTL int64
864879
}
865880

866-
// GetAllHistory returns a list of all history keys present in etcd.
867-
func (c *Client) GetAllHistory() ([]*FlatHistory, error) {
881+
// GetHistory returns a list of matching history keys present in etcd.
882+
func (c *Client) GetHistory(user, service, host, port string) ([]*FlatHistory, error) {
868883
ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout)
869-
resp, err := c.cli.Get(ctx, etcdHistoryPath, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend))
884+
resp, err := c.cli.Get(ctx, etcdHistoryPath+"/"+user, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend))
870885
defer cancel()
871886
if err != nil {
872887
return nil, err
873888
}
874889

875-
history := make([]*FlatHistory, len(resp.Kvs))
876-
for i, ev := range resp.Kvs {
890+
_, nodesetDlclose, nodesetExpand := nodesets.InitExpander()
891+
defer nodesetDlclose()
892+
hosts, err := nodesetExpand(host)
893+
if err != nil {
894+
return nil, err
895+
}
896+
ports, err := nodesetExpand(port)
897+
if err != nil {
898+
return nil, err
899+
}
900+
var history []*FlatHistory
901+
for _, ev := range resp.Kvs {
877902
subkey := string(ev.Key)[len(etcdHistoryPath)+1:]
878903
fields := strings.Split(subkey, "/")
879904
if len(fields) != 2 {
880905
return nil, fmt.Errorf("bad key format %s", subkey)
881906
}
882-
883-
v := &FlatHistory{}
884-
v.User = fields[0]
885-
v.Dest = string(ev.Value)
886-
leaseID, err := strconv.Atoi(fields[1])
907+
evHost, evPort, err := SplitHostPort(string(ev.Value))
887908
if err != nil {
888909
return nil, err
889910
}
890-
ttl, err := c.cli.TimeToLive(ctx, clientv3.LeaseID(leaseID))
891-
if err != nil {
892-
return nil, err
911+
912+
if (user == "" && service == "" && host == "" && port == "") ||
913+
((user == "" || strings.Contains("/"+fields[0], "/"+user+"@")) &&
914+
(service == "" || strings.Contains(fields[0]+"/", "@"+service+"/")) &&
915+
(host == "" || slices.Contains(hosts, evHost)) &&
916+
(port == "" || slices.Contains(ports, evPort))) {
917+
v := &FlatHistory{}
918+
v.User = fields[0]
919+
v.Dest = string(ev.Value)
920+
leaseID, err := strconv.Atoi(fields[1])
921+
if err != nil {
922+
return nil, err
923+
}
924+
ttl, err := c.cli.TimeToLive(ctx, clientv3.LeaseID(leaseID))
925+
if err != nil {
926+
return nil, err
927+
}
928+
v.TTL = ttl.TTL
929+
history = append(history, v)
893930
}
894-
v.TTL = ttl.TTL
895-
history[i] = v
896931
}
897932

898933
return history, nil

‎test/fedora-image/sshproxy_test.go

+59
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,33 @@ func getEtcdConnections() ([]aggConnection, string) {
170170
return connections, jsonStr
171171
}
172172

173+
type aggUser struct {
174+
User string
175+
Service string
176+
Groups string
177+
N int
178+
BwIn int
179+
BwOut int
180+
Dest string
181+
TTL int
182+
}
183+
184+
func getEtcdAllUsers() ([]aggUser, string) {
185+
ctx := context.Background()
186+
_, stdout, _, err := runCommand(ctx, "ssh", []string{"gateway1", "--", fmt.Sprintf("%s show -json users -all", SSHPROXYCTL)}, nil, nil)
187+
if err != nil {
188+
log.Fatal(err)
189+
}
190+
191+
jsonStr := strings.TrimSpace(string(stdout))
192+
var users []aggUser
193+
if err := json.Unmarshal(stdout, &users); err != nil {
194+
log.Fatal(err)
195+
}
196+
197+
return users, jsonStr
198+
}
199+
173200
type host struct {
174201
Hostname string
175202
Port string
@@ -211,6 +238,11 @@ func forgetHost(host string) error {
211238
return err
212239
}
213240

241+
func forgetPersist() error {
242+
_, _, _, err := etcdCommand("forget persist -port 22")
243+
return err
244+
}
245+
214246
type user struct {
215247
User string
216248
Group string
@@ -603,6 +635,33 @@ func TestLongStickyConnections(t *testing.T) {
603635
}
604636
}
605637

638+
func TestForgetPersist(t *testing.T) {
639+
updateLineSSHProxyConf("etcd_keyttl", "3600")
640+
641+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
642+
defer cancel()
643+
args, _ := prepareCommand("gateway1", 2022, "hostname")
644+
_, stdout, _, err := runCommand(ctx, "ssh", args, nil, nil)
645+
if err != nil {
646+
log.Fatal(err)
647+
}
648+
dest := strings.TrimSpace(string(stdout)) + ":22"
649+
650+
users, _ := getEtcdAllUsers()
651+
if users[0].Dest != dest {
652+
t.Errorf("'Persist to' is %s, want %s", users[0].Dest, dest)
653+
}
654+
err = forgetPersist()
655+
if err != nil {
656+
log.Fatal(err)
657+
}
658+
updateLineSSHProxyConf("etcd_keyttl", "0")
659+
users, _ = getEtcdAllUsers()
660+
if users[0].Dest != "" {
661+
t.Errorf("'Persist to' is %s, want empty string", users[0].Dest)
662+
}
663+
}
664+
606665
func TestBalancedConnections(t *testing.T) {
607666
// remove old connections stored in etcd
608667
time.Sleep(4 * time.Second)

0 commit comments

Comments
 (0)
Please sign in to comment.