Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type SSH struct {
type PasswordCallback func() (secret string, err error)

var (
authMethodCache = sync.Map{}
signerCache = sync.Map{}
defaultKeypaths = []string{"~/.ssh/id_rsa", "~/.ssh/identity", "~/.ssh/id_dsa", "~/.ssh/id_ecdsa", "~/.ssh/id_ed25519"}
dummyhostKeyPaths []string
globalOnce sync.Once
Expand Down Expand Up @@ -407,33 +407,33 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop
log.Tracef("%s: using %d passed-in auth methods", c, len(c.AuthMethods))
config.Auth = c.AuthMethods
} else if len(signers) > 0 {
log.Debugf("%s: using all keys (%d) from ssh agent because a keypath was not explicitly given", c, len(signers))
config.Auth = append(config.Auth, ssh.PublicKeys(signers...))
log.Debugf("%s: using all keys (%d) from ssh agent", c, len(signers))
}

for _, keyPath := range c.keyPaths {
if am, ok := authMethodCache.Load(keyPath); ok {
if am, ok := signerCache.Load(keyPath); ok {
switch authM := am.(type) {
case ssh.AuthMethod:
log.Tracef("%s: using cached auth method for %s", c, keyPath)
config.Auth = append(config.Auth, authM)
case ssh.Signer:
log.Tracef("%s: using cached signer for %s", c, keyPath)
signers = append(signers, authM)
case error:
log.Tracef("%s: already discarded key %s: %v", c, keyPath, authM)
default:
log.Tracef("%s: unexpected type %T for cached auth method for %s", c, am, keyPath)
}
continue
}
privateKeyAuth, err := c.pkeySigner(signers, keyPath)
signer, err := c.pkeySigner(signers, keyPath)
if err != nil {
log.Debugf("%s: failed to obtain a signer for identity %s: %v", c, keyPath, err)
// store the error so this key won't be loaded again
authMethodCache.Store(keyPath, err)
signerCache.Store(keyPath, err)
} else {
authMethodCache.Store(keyPath, privateKeyAuth)
config.Auth = append(config.Auth, privateKeyAuth)
signerCache.Store(keyPath, signer)
signers = append(signers, signer)
}
}
config.Auth = append(config.Auth, ssh.PublicKeys(signers...))

if len(config.Auth) == 0 {
return nil, fmt.Errorf("%w: no usable authentication method found", ErrCantConnect)
Expand Down Expand Up @@ -489,22 +489,22 @@ func (c *SSH) Connect() error {
return nil
}

func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) {
func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.Signer, error) {
if len(signers) == 0 {
return nil, fmt.Errorf("%w: signer not found for public key", ErrCantConnect)
}

for _, s := range signers {
if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) {
log.Debugf("%s: signer for public key available in ssh agent", c)
return ssh.PublicKeys(s), nil
return s, nil
}
}

return nil, fmt.Errorf("%w: the provided key is a public key and is not known by agent", ErrAuthFailed)
}

func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) {
func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.Signer, error) {
log.Tracef("%s: checking identity file %s", c, path)
key, err := os.ReadFile(path)
if err != nil {
Expand All @@ -520,7 +520,7 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err
signer, err := ssh.ParsePrivateKey(key)
if err == nil {
log.Debugf("%s: using an unencrypted private key from %s", c, path)
return ssh.PublicKeys(signer), nil
return signer, nil
}

var ppErr *ssh.PassphraseMissingError
Expand All @@ -543,7 +543,7 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err
if err != nil {
return nil, fmt.Errorf("%w: protected key %s decoding failed: %w", ErrCantConnect, path, err)
}
return ssh.PublicKeys(signer), nil
return signer, nil
}
}

Expand Down