From 7ce08bfa04076c1bcaa6b37de46833e2abdb29c6 Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 24 Nov 2025 17:02:27 -0800 Subject: [PATCH 1/4] Intraprocess credential caching --- auth.go | 175 +++++++++++++++++++++----------------- auth_oauth.go | 106 +++++++++++------------ secure_storage_manager.go | 83 ++++++++++++++++++ 3 files changed, 233 insertions(+), 131 deletions(-) diff --git a/auth.go b/auth.go index 361b3ea0d..59290630a 100644 --- a/auth.go +++ b/auth.go @@ -430,15 +430,15 @@ func authenticate( if !respd.Success { logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) - if sessionParameters[clientRequestMfaToken] == true { - credentialsStorage.deleteCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) - } - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - } - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { - credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) - } + if sessionParameters[clientRequestMfaToken] == true { + _ = deleteCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + } + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { + _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + } + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { + _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + } code, err := strconv.Atoi(respd.Code) if err != nil { return nil, err @@ -452,19 +452,19 @@ func authenticate( logger.WithContext(ctx).Info("Authentication SUCCESS") sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) - if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { - token := respd.Data.MfaToken - credentialsStorage.setCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) - } - - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - token := respd.Data.IDToken - // XXX: for some reason, token is empty here some times and we - // don't want to clear the cache, so let's skip it if it's empty - if token != "" { - credentialsStorage.setCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) - } - } + if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { + token := respd.Data.MfaToken + _ = setCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) + } + + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { + token := respd.Data.IDToken + // XXX: for some reason, token is empty here some times and we + // don't want to clear the cache, so let's skip it if it's empty + if token != "" { + _ = setCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) + } + } return &respd.Data, nil } @@ -633,13 +633,13 @@ func authenticateByAuthorizationCode(sc *snowflakeConn, lease *Lease) (string, e lockKey := newOAuthAuthorizationCodeLockKey(oauthClient.tokenURL(), sc.cfg.User) valueAwaiter := valueAwaitHolder.get(lockKey) defer valueAwaiter.resumeOne() - token, err := awaitValue(valueAwaiter, func() (string, error) { - return credentialsStorage.getCredential(lease, newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) - }, func(s string, err error) bool { - return s != "" - }, func() string { - return "" - }) + token, err := awaitValue(valueAwaiter, func() (string, error) { + return getCredentialFast(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) + }, func(s string, err error) bool { + return s != "" + }, func() string { + return "" + }) if err != nil || token != "" { return token, err } @@ -742,11 +742,14 @@ func authenticateWithConfig(sc *snowflakeConn) error { mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User) idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User) - lease, err := credentialsStorage.acquireLease() - if err != nil { - return err - } - defer lease.Release() + // Avoid acquiring the global lease up front; use fast-path lookups that + // consult the in-memory cache first and only acquire the lease on misses. + var lease *Lease + defer func() { + if lease != nil { + lease.Release() + } + }() key := extBrowserBackoffKey(sc.cfg.Host, sc.cfg.User) @@ -762,7 +765,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.IDToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, _ := credentialsStorage.getCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) return tok, nil }, func(s string, err error) bool { @@ -773,10 +776,10 @@ func authenticateWithConfig(sc *snowflakeConn) error { }, ) - } else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - tok, _ := credentialsStorage.getCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - sc.cfg.IDToken = tok - } + } else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + sc.cfg.IDToken = tok + } // Disable console login by default if sc.cfg.DisableConsoleLogin == configBoolNotSet { sc.cfg.DisableConsoleLogin = ConfigBoolTrue @@ -793,7 +796,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.MfaToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, err := credentialsStorage.getCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) if err != nil { logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) } @@ -804,13 +807,13 @@ func authenticateWithConfig(sc *snowflakeConn) error { return "" }, ) - } else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue { - tok, err := credentialsStorage.getCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) - if err != nil { - logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) - } - sc.cfg.MfaToken = tok - } + } else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue { + tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + if err != nil { + logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) + } + sc.cfg.MfaToken = tok + } } } @@ -829,28 +832,36 @@ func authenticateWithConfig(sc *snowflakeConn) error { lastFail.Delete(key) } - samlResponse, proofKey, err = authenticateByExternalBrowser( - sc.ctx, - lease, - sc.rest, - sc.cfg.Authenticator.String(), - sc.cfg.Application, - sc.cfg.Account, - sc.cfg.User, - sc.cfg.ExternalBrowserTimeout, - sc.cfg.DisableConsoleLogin) + // External browser path requires a live lease for periodic renewals + if lease == nil { + lease, err = credentialsStorage.acquireLease() + if err != nil { + sc.cleanup() + return err + } + } + samlResponse, proofKey, err = authenticateByExternalBrowser( + sc.ctx, + lease, + sc.rest, + sc.cfg.Authenticator.String(), + sc.cfg.Application, + sc.cfg.Account, + sc.cfg.User, + sc.cfg.ExternalBrowserTimeout, + sc.cfg.DisableConsoleLogin) if err != nil { sc.cleanup() return err } } } - authData, err = authenticate( - sc.ctx, - lease, - sc, - samlResponse, - proofKey) + authData, err = authenticate( + sc.ctx, + lease, + sc, + samlResponse, + proofKey) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { sc.cleanup() @@ -861,9 +872,9 @@ func authenticateWithConfig(sc *snowflakeConn) error { switch { // Case 1: cached ID token failed -> clear + try one interactive refresh - case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "": - credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - sc.cfg.IDToken = "" + case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "": + _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + sc.cfg.IDToken = "" // NOTE on tabstorms: // We intentionally do NOT pre-gate the cached-ID-token fallback with a LoadOrStore @@ -874,11 +885,19 @@ func authenticateWithConfig(sc *snowflakeConn) error { // considered low probability and low impact // reenter interactive flow - samlResponse, proofKey, err = authenticateByExternalBrowser( - sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(), - sc.cfg.Application, sc.cfg.Account, sc.cfg.User, - sc.cfg.ExternalBrowserTimeout, sc.cfg.DisableConsoleLogin, - ) + if lease == nil { + lease, err = credentialsStorage.acquireLease() + if err != nil { + // let SAML-phase failure remain retryable; do not set backoff + sc.cleanup() + return err + } + } + samlResponse, proofKey, err = authenticateByExternalBrowser( + sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(), + sc.cfg.Application, sc.cfg.Account, sc.cfg.User, + sc.cfg.ExternalBrowserTimeout, sc.cfg.DisableConsoleLogin, + ) if err != nil { // let SAML-phase failure remain retryable; do not set backoff sc.cleanup() @@ -892,12 +911,12 @@ func authenticateWithConfig(sc *snowflakeConn) error { fallthrough // Case 2: still failing, but could be an OAuth refreshable error - case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)): - credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)): + _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) - if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { - doRefreshTokenWithLease(sc, lease) - } + if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { + doRefreshTokenWithLease(sc, lease) + } // if refreshing succeeds for authorization code, we will take a token from cache // if it fails, we will just run the full flow @@ -941,7 +960,7 @@ func doRefreshTokenWithLock(sc *snowflakeConn, lease *Lease) { if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) { if err = oauthClient.refreshToken(lease); err != nil { logger.Warnf("cannot refresh token. %v", err) - credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) return "", err } return "", nil @@ -957,7 +976,7 @@ func doRefreshTokenWithLease(sc *snowflakeConn, lease *Lease) { } else { if rfErr := oauthClient.refreshToken(lease); rfErr != nil { logger.Warnf("cannot refresh token. %v", rfErr) - credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) } } diff --git a/auth_oauth.go b/auth_oauth.go index 8d83023af..cdd3440b3 100644 --- a/auth_oauth.go +++ b/auth_oauth.go @@ -91,16 +91,16 @@ type oauthBrowserResult struct { } func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Lease) (string, error) { - accessTokenSpec := oauthClient.accessTokenSpec() - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := credentialsStorage.getCredential(lease, accessTokenSpec); accessToken != "" { - logger.Debugf("Access token retrieved from cache: %v", err) - return accessToken, nil - } - if refreshToken, _ := credentialsStorage.getCredential(lease, oauthClient.refreshTokenSpec()); refreshToken != "" { - return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent} - } - } + accessTokenSpec := oauthClient.accessTokenSpec() + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + logger.Debugf("Access token retrieved from cache: %v", err) + return accessToken, nil + } + if refreshToken, _ := getCredentialFast(oauthClient.refreshTokenSpec()); refreshToken != "" { + return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent} + } + } logger.Debugf("Access token not present in cache, running full auth code flow") resultChan := make(chan oauthBrowserResult, 1) @@ -121,13 +121,13 @@ func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Leas case <-time.After(oauthClient.cfg.ExternalBrowserTimeout): return "", errors.New("authentication via browser timed out") case result := <-resultChan: - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - logger.Debug("saving oauth access token in cache") - credentialsStorage.setCredential(lease, oauthClient.accessTokenSpec(), result.accessToken) - credentialsStorage.setCredential(lease, oauthClient.refreshTokenSpec(), result.refreshToken) - } - return result.accessToken, result.err - } + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + logger.Debug("saving oauth access token in cache") + _ = setCredentialWithLease(oauthClient.accessTokenSpec(), result.accessToken) + _ = setCredentialWithLease(oauthClient.refreshTokenSpec(), result.refreshToken) + } + return result.accessToken, result.err + } } func (oauthClient *oauthClient) doAuthenticateByOAuthAuthorizationCode(tcpListener *net.TCPListener, callbackPort int) oauthBrowserResult { @@ -353,14 +353,14 @@ func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() stri func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Lease) (string, error) { accessTokenSpec := oauthClient.accessTokenSpec() - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := credentialsStorage.getCredential(lease, accessTokenSpec); accessToken != "" { - if err != nil { - return "", fmt.Errorf("error retrieving access token from cache: %w", err) - } - return accessToken, nil - } - } + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + if err != nil { + return "", fmt.Errorf("error retrieving access token from cache: %w", err) + } + return accessToken, nil + } + } oauth2Cfg, err := oauthClient.buildClientCredentialsConfig() if err != nil { return "", err @@ -369,10 +369,10 @@ func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Leas if err != nil { return "", err } - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - credentialsStorage.setCredential(lease, accessTokenSpec, token.AccessToken) - } - return token.AccessToken, nil + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + _ = setCredentialWithLease(accessTokenSpec, token.AccessToken) + } + return token.AccessToken, nil } func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentials.Config, error) { @@ -388,16 +388,16 @@ func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentia } func (oauthClient *oauthClient) refreshToken(lease *Lease) error { - if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue { - logger.Debug("credentials storage is disabled, cannot use refresh tokens") - return nil - } - refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User) - refreshToken, _ := credentialsStorage.getCredential(lease, refreshTokenSpec) - if refreshToken == "" { - logger.Debug("no refresh token in cache, full flow must be run") - return nil - } + if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue { + logger.Debug("credentials storage is disabled, cannot use refresh tokens") + return nil + } + refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User) + refreshToken, _ := getCredentialFast(refreshTokenSpec) + if refreshToken == "" { + logger.Debug("no refresh token in cache, full flow must be run") + return nil + } body := url.Values{} body.Add("grant_type", "refresh_token") body.Add("refresh_token", refreshToken) @@ -417,24 +417,24 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error { logger.Warnf("error while closing response body for %v. %v", req.URL, err) } }() - if resp.StatusCode != 200 { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - credentialsStorage.deleteCredential(lease, refreshTokenSpec) - return errors.New(string(respBody)) - } + if resp.StatusCode != 200 { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + _ = deleteCredentialWithLease(refreshTokenSpec) + return errors.New(string(respBody)) + } var tokenResponse tokenExchangeResponseBody if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return err } - accessTokenSpec := oauthClient.accessTokenSpec() - credentialsStorage.setCredential(lease, accessTokenSpec, tokenResponse.AccessToken) - if tokenResponse.RefreshToken != "" { - credentialsStorage.setCredential(lease, refreshTokenSpec, tokenResponse.RefreshToken) - } - return nil + accessTokenSpec := oauthClient.accessTokenSpec() + _ = setCredentialWithLease(accessTokenSpec, tokenResponse.AccessToken) + if tokenResponse.RefreshToken != "" { + _ = setCredentialWithLease(refreshTokenSpec, tokenResponse.RefreshToken) + } + return nil } type tokenExchangeResponseBody struct { diff --git a/secure_storage_manager.go b/secure_storage_manager.go index ab51993c9..164dc2e09 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -14,6 +14,7 @@ import ( "sync" "sync/atomic" "time" + "github.com/99designs/keyring" ) @@ -119,6 +120,59 @@ type secureStorageManager interface { var credentialsStorage = newSecureStorageManager() +// Helper fast-paths to minimize lease contention for common operations. +// These acquire a lease only if needed. + +// getCredentialFast attempts an in-memory read first (if available), +// otherwise acquires a lease to consult the persistent cache. +func getCredentialFast(tokenSpec *secureTokenSpec) (string, error) { + // Try in-memory if file-based storage + if fb, ok := credentialsStorage.(*fileBasedSecureStorageManager); ok { + key, err := tokenSpec.buildKey() + if err != nil { + return "", err + } + fb.memMu.RLock() + if v, ok := fb.mem[key]; ok { + fb.memMu.RUnlock() + return v, nil + } + fb.memMu.RUnlock() + } + + // Miss: acquire lease and consult backing store + lease, err := credentialsStorage.acquireLease() + if err != nil { + return "", err + } + if lease != nil { + defer lease.Release() + } + return credentialsStorage.getCredential(lease, tokenSpec) +} + +func setCredentialWithLease(tokenSpec *secureTokenSpec, value string) error { + lease, err := credentialsStorage.acquireLease() + if err != nil { + return err + } + if lease != nil { + defer lease.Release() + } + return credentialsStorage.setCredential(lease, tokenSpec, value) +} + +func deleteCredentialWithLease(tokenSpec *secureTokenSpec) error { + lease, err := credentialsStorage.acquireLease() + if err != nil { + return err + } + if lease != nil { + defer lease.Release() + } + return credentialsStorage.deleteCredential(lease, tokenSpec) +} + func newSecureStorageManager() secureStorageManager { var ssm secureStorageManager var err error @@ -139,6 +193,9 @@ func newSecureStorageManager() secureStorageManager { type fileBasedSecureStorageManager struct { credDirPath string leaseHandler *LeaseHandler + // in-memory fast path cache to reduce file lock contention within a process + memMu sync.RWMutex + mem map[string]string } func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) { @@ -153,6 +210,7 @@ func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) ssm := &fileBasedSecureStorageManager{ credDirPath: credDirPath, leaseHandler: leaseHandler, + mem: make(map[string]string), } return ssm, nil } @@ -299,6 +357,14 @@ func (ssm *fileBasedSecureStorageManager) setCredential(lease *Lease, tokenSpec return err } + // Update in-memory cache fast path first to satisfy concurrent readers + // in this process without hitting the filesystem. + if credentialsKey != "" && value != "" { + ssm.memMu.Lock() + ssm.mem[credentialsKey] = value + ssm.memMu.Unlock() + } + return ssm.withCacheFile(lease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { @@ -397,6 +463,14 @@ func (ssm *fileBasedSecureStorageManager) getCredential(lease *Lease, tokenSpec return "", err } + // Fast path: serve from in-memory cache if present + ssm.memMu.RLock() + if v, ok := ssm.mem[credentialsKey]; ok { + ssm.memMu.RUnlock() + return v, nil + } + ssm.memMu.RUnlock() + ret := "" err = ssm.withCacheFile(lease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) @@ -415,6 +489,10 @@ func (ssm *fileBasedSecureStorageManager) getCredential(lease *Lease, tokenSpec } ret = credStr + // Populate in-memory cache for subsequent readers + ssm.memMu.Lock() + ssm.mem[credentialsKey] = credStr + ssm.memMu.Unlock() return nil }) return ret, err @@ -503,6 +581,11 @@ func (ssm *fileBasedSecureStorageManager) deleteCredential(lease *Lease, tokenSp return err } + // Remove from in-memory cache first + ssm.memMu.Lock() + delete(ssm.mem, credentialsKey) + ssm.memMu.Unlock() + return ssm.withCacheFile(lease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { From 5d801d78325b9139a2103d62b79dd6ab1474755d Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 24 Nov 2025 22:29:54 -0800 Subject: [PATCH 2/4] go fmt --- auth.go | 196 +++++++++++++++++++++++------------------------ auth_oauth.go | 106 ++++++++++++------------- connection.go | 2 +- value_awaiter.go | 2 +- 4 files changed, 153 insertions(+), 153 deletions(-) diff --git a/auth.go b/auth.go index 59290630a..10aad43cc 100644 --- a/auth.go +++ b/auth.go @@ -430,15 +430,15 @@ func authenticate( if !respd.Success { logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) - if sessionParameters[clientRequestMfaToken] == true { - _ = deleteCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) - } - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - } - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { - _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) - } + if sessionParameters[clientRequestMfaToken] == true { + _ = deleteCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + } + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { + _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + } + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { + _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + } code, err := strconv.Atoi(respd.Code) if err != nil { return nil, err @@ -452,19 +452,19 @@ func authenticate( logger.WithContext(ctx).Info("Authentication SUCCESS") sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) - if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { - token := respd.Data.MfaToken - _ = setCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) - } - - if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - token := respd.Data.IDToken - // XXX: for some reason, token is empty here some times and we - // don't want to clear the cache, so let's skip it if it's empty - if token != "" { - _ = setCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) - } - } + if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { + token := respd.Data.MfaToken + _ = setCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) + } + + if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { + token := respd.Data.IDToken + // XXX: for some reason, token is empty here some times and we + // don't want to clear the cache, so let's skip it if it's empty + if token != "" { + _ = setCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) + } + } return &respd.Data, nil } @@ -633,13 +633,13 @@ func authenticateByAuthorizationCode(sc *snowflakeConn, lease *Lease) (string, e lockKey := newOAuthAuthorizationCodeLockKey(oauthClient.tokenURL(), sc.cfg.User) valueAwaiter := valueAwaitHolder.get(lockKey) defer valueAwaiter.resumeOne() - token, err := awaitValue(valueAwaiter, func() (string, error) { - return getCredentialFast(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) - }, func(s string, err error) bool { - return s != "" - }, func() string { - return "" - }) + token, err := awaitValue(valueAwaiter, func() (string, error) { + return getCredentialFast(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) + }, func(s string, err error) bool { + return s != "" + }, func() string { + return "" + }) if err != nil || token != "" { return token, err } @@ -742,14 +742,14 @@ func authenticateWithConfig(sc *snowflakeConn) error { mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User) idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User) - // Avoid acquiring the global lease up front; use fast-path lookups that - // consult the in-memory cache first and only acquire the lease on misses. - var lease *Lease - defer func() { - if lease != nil { - lease.Release() - } - }() + // Avoid acquiring the global lease up front; use fast-path lookups that + // consult the in-memory cache first and only acquire the lease on misses. + var lease *Lease + defer func() { + if lease != nil { + lease.Release() + } + }() key := extBrowserBackoffKey(sc.cfg.Host, sc.cfg.User) @@ -765,7 +765,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.IDToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) return tok, nil }, func(s string, err error) bool { @@ -776,10 +776,10 @@ func authenticateWithConfig(sc *snowflakeConn) error { }, ) - } else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - sc.cfg.IDToken = tok - } + } else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + sc.cfg.IDToken = tok + } // Disable console login by default if sc.cfg.DisableConsoleLogin == configBoolNotSet { sc.cfg.DisableConsoleLogin = ConfigBoolTrue @@ -796,7 +796,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.MfaToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) if err != nil { logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) } @@ -807,13 +807,13 @@ func authenticateWithConfig(sc *snowflakeConn) error { return "" }, ) - } else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue { - tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) - if err != nil { - logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) - } - sc.cfg.MfaToken = tok - } + } else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue { + tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + if err != nil { + logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) + } + sc.cfg.MfaToken = tok + } } } @@ -832,36 +832,36 @@ func authenticateWithConfig(sc *snowflakeConn) error { lastFail.Delete(key) } - // External browser path requires a live lease for periodic renewals - if lease == nil { - lease, err = credentialsStorage.acquireLease() - if err != nil { - sc.cleanup() - return err - } - } - samlResponse, proofKey, err = authenticateByExternalBrowser( - sc.ctx, - lease, - sc.rest, - sc.cfg.Authenticator.String(), - sc.cfg.Application, - sc.cfg.Account, - sc.cfg.User, - sc.cfg.ExternalBrowserTimeout, - sc.cfg.DisableConsoleLogin) + // External browser path requires a live lease for periodic renewals + if lease == nil { + lease, err = credentialsStorage.acquireLease() + if err != nil { + sc.cleanup() + return err + } + } + samlResponse, proofKey, err = authenticateByExternalBrowser( + sc.ctx, + lease, + sc.rest, + sc.cfg.Authenticator.String(), + sc.cfg.Application, + sc.cfg.Account, + sc.cfg.User, + sc.cfg.ExternalBrowserTimeout, + sc.cfg.DisableConsoleLogin) if err != nil { sc.cleanup() return err } } } - authData, err = authenticate( - sc.ctx, - lease, - sc, - samlResponse, - proofKey) + authData, err = authenticate( + sc.ctx, + lease, + sc, + samlResponse, + proofKey) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { sc.cleanup() @@ -872,9 +872,9 @@ func authenticateWithConfig(sc *snowflakeConn) error { switch { // Case 1: cached ID token failed -> clear + try one interactive refresh - case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "": - _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) - sc.cfg.IDToken = "" + case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "": + _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + sc.cfg.IDToken = "" // NOTE on tabstorms: // We intentionally do NOT pre-gate the cached-ID-token fallback with a LoadOrStore @@ -885,19 +885,19 @@ func authenticateWithConfig(sc *snowflakeConn) error { // considered low probability and low impact // reenter interactive flow - if lease == nil { - lease, err = credentialsStorage.acquireLease() - if err != nil { - // let SAML-phase failure remain retryable; do not set backoff - sc.cleanup() - return err - } - } - samlResponse, proofKey, err = authenticateByExternalBrowser( - sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(), - sc.cfg.Application, sc.cfg.Account, sc.cfg.User, - sc.cfg.ExternalBrowserTimeout, sc.cfg.DisableConsoleLogin, - ) + if lease == nil { + lease, err = credentialsStorage.acquireLease() + if err != nil { + // let SAML-phase failure remain retryable; do not set backoff + sc.cleanup() + return err + } + } + samlResponse, proofKey, err = authenticateByExternalBrowser( + sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(), + sc.cfg.Application, sc.cfg.Account, sc.cfg.User, + sc.cfg.ExternalBrowserTimeout, sc.cfg.DisableConsoleLogin, + ) if err != nil { // let SAML-phase failure remain retryable; do not set backoff sc.cleanup() @@ -910,13 +910,13 @@ func authenticateWithConfig(sc *snowflakeConn) error { } fallthrough - // Case 2: still failing, but could be an OAuth refreshable error - case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)): - _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + // Case 2: still failing, but could be an OAuth refreshable error + case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)): + _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) - if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { - doRefreshTokenWithLease(sc, lease) - } + if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { + doRefreshTokenWithLease(sc, lease) + } // if refreshing succeeds for authorization code, we will take a token from cache // if it fails, we will just run the full flow @@ -960,7 +960,7 @@ func doRefreshTokenWithLock(sc *snowflakeConn, lease *Lease) { if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) { if err = oauthClient.refreshToken(lease); err != nil { logger.Warnf("cannot refresh token. %v", err) - _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) return "", err } return "", nil @@ -976,7 +976,7 @@ func doRefreshTokenWithLease(sc *snowflakeConn, lease *Lease) { } else { if rfErr := oauthClient.refreshToken(lease); rfErr != nil { logger.Warnf("cannot refresh token. %v", rfErr) - _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) } } diff --git a/auth_oauth.go b/auth_oauth.go index cdd3440b3..3180c7147 100644 --- a/auth_oauth.go +++ b/auth_oauth.go @@ -91,16 +91,16 @@ type oauthBrowserResult struct { } func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Lease) (string, error) { - accessTokenSpec := oauthClient.accessTokenSpec() - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { - logger.Debugf("Access token retrieved from cache: %v", err) - return accessToken, nil - } - if refreshToken, _ := getCredentialFast(oauthClient.refreshTokenSpec()); refreshToken != "" { - return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent} - } - } + accessTokenSpec := oauthClient.accessTokenSpec() + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + logger.Debugf("Access token retrieved from cache: %v", err) + return accessToken, nil + } + if refreshToken, _ := getCredentialFast(oauthClient.refreshTokenSpec()); refreshToken != "" { + return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent} + } + } logger.Debugf("Access token not present in cache, running full auth code flow") resultChan := make(chan oauthBrowserResult, 1) @@ -121,13 +121,13 @@ func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Leas case <-time.After(oauthClient.cfg.ExternalBrowserTimeout): return "", errors.New("authentication via browser timed out") case result := <-resultChan: - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - logger.Debug("saving oauth access token in cache") - _ = setCredentialWithLease(oauthClient.accessTokenSpec(), result.accessToken) - _ = setCredentialWithLease(oauthClient.refreshTokenSpec(), result.refreshToken) - } - return result.accessToken, result.err - } + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + logger.Debug("saving oauth access token in cache") + _ = setCredentialWithLease(oauthClient.accessTokenSpec(), result.accessToken) + _ = setCredentialWithLease(oauthClient.refreshTokenSpec(), result.refreshToken) + } + return result.accessToken, result.err + } } func (oauthClient *oauthClient) doAuthenticateByOAuthAuthorizationCode(tcpListener *net.TCPListener, callbackPort int) oauthBrowserResult { @@ -353,14 +353,14 @@ func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() stri func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Lease) (string, error) { accessTokenSpec := oauthClient.accessTokenSpec() - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { - if err != nil { - return "", fmt.Errorf("error retrieving access token from cache: %w", err) - } - return accessToken, nil - } - } + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + if err != nil { + return "", fmt.Errorf("error retrieving access token from cache: %w", err) + } + return accessToken, nil + } + } oauth2Cfg, err := oauthClient.buildClientCredentialsConfig() if err != nil { return "", err @@ -369,10 +369,10 @@ func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Leas if err != nil { return "", err } - if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - _ = setCredentialWithLease(accessTokenSpec, token.AccessToken) - } - return token.AccessToken, nil + if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { + _ = setCredentialWithLease(accessTokenSpec, token.AccessToken) + } + return token.AccessToken, nil } func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentials.Config, error) { @@ -388,16 +388,16 @@ func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentia } func (oauthClient *oauthClient) refreshToken(lease *Lease) error { - if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue { - logger.Debug("credentials storage is disabled, cannot use refresh tokens") - return nil - } - refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User) - refreshToken, _ := getCredentialFast(refreshTokenSpec) - if refreshToken == "" { - logger.Debug("no refresh token in cache, full flow must be run") - return nil - } + if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue { + logger.Debug("credentials storage is disabled, cannot use refresh tokens") + return nil + } + refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User) + refreshToken, _ := getCredentialFast(refreshTokenSpec) + if refreshToken == "" { + logger.Debug("no refresh token in cache, full flow must be run") + return nil + } body := url.Values{} body.Add("grant_type", "refresh_token") body.Add("refresh_token", refreshToken) @@ -417,24 +417,24 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error { logger.Warnf("error while closing response body for %v. %v", req.URL, err) } }() - if resp.StatusCode != 200 { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - _ = deleteCredentialWithLease(refreshTokenSpec) - return errors.New(string(respBody)) - } + if resp.StatusCode != 200 { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + _ = deleteCredentialWithLease(refreshTokenSpec) + return errors.New(string(respBody)) + } var tokenResponse tokenExchangeResponseBody if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return err } - accessTokenSpec := oauthClient.accessTokenSpec() - _ = setCredentialWithLease(accessTokenSpec, tokenResponse.AccessToken) - if tokenResponse.RefreshToken != "" { - _ = setCredentialWithLease(refreshTokenSpec, tokenResponse.RefreshToken) - } - return nil + accessTokenSpec := oauthClient.accessTokenSpec() + _ = setCredentialWithLease(accessTokenSpec, tokenResponse.AccessToken) + if tokenResponse.RefreshToken != "" { + _ = setCredentialWithLease(refreshTokenSpec, tokenResponse.RefreshToken) + } + return nil } type tokenExchangeResponseBody struct { diff --git a/connection.go b/connection.go index 246239a03..8164488d2 100644 --- a/connection.go +++ b/connection.go @@ -565,7 +565,7 @@ func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bin JSON: data.Data.RowSet, RowSetBase64: data.Data.RowSetBase64, }, - queryID: data.Data.QueryID, + queryID: data.Data.QueryID, resultIDs: resultIDs, } // if multistatement is used, we need to set the first result set to actual result set, not the aggregated response diff --git a/value_awaiter.go b/value_awaiter.go index 331c4bd8d..14dde4b07 100644 --- a/value_awaiter.go +++ b/value_awaiter.go @@ -61,7 +61,7 @@ func awaitValue[T any](valueAwaiter *valueAwaiterType, runFunc func() (T, error) return defaultFactoryFunc(), nil } } - + // Value is ready - all threads should return the value. logger.Tracef("awaitValue[%v] value was ready after wait", goroutineID()) valueAwaiter.mu.Unlock() From be4147127919cd207bc7d105a484cb53ebae6267 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 25 Nov 2025 01:22:11 -0800 Subject: [PATCH 3/4] Centralize cache behavior + defensive lease handling --- auth.go | 53 ++++++++++------ auth_oauth.go | 20 +++--- secure_storage_manager.go | 125 +++++++++++++++++++++----------------- 3 files changed, 113 insertions(+), 85 deletions(-) diff --git a/auth.go b/auth.go index 10aad43cc..1c77dff9d 100644 --- a/auth.go +++ b/auth.go @@ -431,13 +431,13 @@ func authenticate( logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) if sessionParameters[clientRequestMfaToken] == true { - _ = deleteCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { - _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) } code, err := strconv.Atoi(respd.Code) if err != nil { @@ -454,7 +454,7 @@ func authenticate( if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { token := respd.Data.MfaToken - _ = setCredentialWithLease(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) + _ = credentialsStorage.setCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { @@ -462,7 +462,7 @@ func authenticate( // XXX: for some reason, token is empty here some times and we // don't want to clear the cache, so let's skip it if it's empty if token != "" { - _ = setCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) + _ = credentialsStorage.setCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) } } return &respd.Data, nil @@ -634,7 +634,7 @@ func authenticateByAuthorizationCode(sc *snowflakeConn, lease *Lease) (string, e valueAwaiter := valueAwaitHolder.get(lockKey) defer valueAwaiter.resumeOne() token, err := awaitValue(valueAwaiter, func() (string, error) { - return getCredentialFast(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) + return credentialsStorage.getCredentialRelaxed(lease, newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)) }, func(s string, err error) bool { return s != "" }, func() string { @@ -734,6 +734,13 @@ func extBrowserBackoffKey(host, user string) string { // Authenticate with sc.cfg func authenticateWithConfig(sc *snowflakeConn) error { + return doAuthenticateWithConfig(sc, nil) +} + +// Perform authentication and consult caches +// First, attempt a relaxed read without a lease +// If this is unsuccessful, recurse once with a lease +func doAuthenticateWithConfig(sc *snowflakeConn, optionalLease *Lease) error { var authData *authResponseMain var samlResponse []byte var proofKey []byte @@ -742,11 +749,13 @@ func authenticateWithConfig(sc *snowflakeConn) error { mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User) idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User) - // Avoid acquiring the global lease up front; use fast-path lookups that - // consult the in-memory cache first and only acquire the lease on misses. - var lease *Lease + // Lease for coordination - may be provided or acquired locally + var lease *Lease = optionalLease + leaseAcquiredLocally := false + + // Release lease at the end if we acquired it locally defer func() { - if lease != nil { + if leaseAcquiredLocally && lease != nil { lease.Release() } }() @@ -765,7 +774,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.IDToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, _ := credentialsStorage.getCredentialRelaxed(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) return tok, nil }, func(s string, err error) bool { @@ -777,7 +786,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { ) } else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - tok, _ := getCredentialFast(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, _ := credentialsStorage.getCredentialRelaxed(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) sc.cfg.IDToken = tok } // Disable console login by default @@ -796,7 +805,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.MfaToken, _ = awaitValue( valueAwaiter, func() (string, error) { - tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, err := credentialsStorage.getCredentialRelaxed(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) if err != nil { logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) } @@ -808,7 +817,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { }, ) } else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue { - tok, err := getCredentialFast(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + tok, err := credentialsStorage.getCredentialRelaxed(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) if err != nil { logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err) } @@ -832,14 +841,18 @@ func authenticateWithConfig(sc *snowflakeConn) error { lastFail.Delete(key) } - // External browser path requires a live lease for periodic renewals + // External browser path requires a lease for renewals + // When using parallel login coordination, we need the lease but cannot recurse + // due to the valueAwaiter if lease == nil { lease, err = credentialsStorage.acquireLease() if err != nil { sc.cleanup() return err } + leaseAcquiredLocally = true } + samlResponse, proofKey, err = authenticateByExternalBrowser( sc.ctx, lease, @@ -856,6 +869,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { } } } + authData, err = authenticate( sc.ctx, lease, @@ -873,7 +887,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { switch { // Case 1: cached ID token failed -> clear + try one interactive refresh case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "": - _ = deleteCredentialWithLease(newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) sc.cfg.IDToken = "" // NOTE on tabstorms: @@ -892,6 +906,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cleanup() return err } + leaseAcquiredLocally = true } samlResponse, proofKey, err = authenticateByExternalBrowser( sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(), @@ -912,7 +927,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { // Case 2: still failing, but could be an OAuth refreshable error case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)): - _ = deleteCredentialWithLease(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { doRefreshTokenWithLease(sc, lease) @@ -960,7 +975,7 @@ func doRefreshTokenWithLock(sc *snowflakeConn, lease *Lease) { if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) { if err = oauthClient.refreshToken(lease); err != nil { logger.Warnf("cannot refresh token. %v", err) - _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) return "", err } return "", nil @@ -976,7 +991,7 @@ func doRefreshTokenWithLease(sc *snowflakeConn, lease *Lease) { } else { if rfErr := oauthClient.refreshToken(lease); rfErr != nil { logger.Warnf("cannot refresh token. %v", rfErr) - _ = deleteCredentialWithLease(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) } } diff --git a/auth_oauth.go b/auth_oauth.go index 3180c7147..9b1d9b720 100644 --- a/auth_oauth.go +++ b/auth_oauth.go @@ -93,11 +93,11 @@ type oauthBrowserResult struct { func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Lease) (string, error) { accessTokenSpec := oauthClient.accessTokenSpec() if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + if accessToken, err := credentialsStorage.getCredentialRelaxed(lease, accessTokenSpec); accessToken != "" { logger.Debugf("Access token retrieved from cache: %v", err) return accessToken, nil } - if refreshToken, _ := getCredentialFast(oauthClient.refreshTokenSpec()); refreshToken != "" { + if refreshToken, _ := credentialsStorage.getCredentialRelaxed(lease, oauthClient.refreshTokenSpec()); refreshToken != "" { return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent} } } @@ -123,8 +123,8 @@ func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Leas case result := <-resultChan: if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { logger.Debug("saving oauth access token in cache") - _ = setCredentialWithLease(oauthClient.accessTokenSpec(), result.accessToken) - _ = setCredentialWithLease(oauthClient.refreshTokenSpec(), result.refreshToken) + _ = credentialsStorage.setCredential(lease, oauthClient.accessTokenSpec(), result.accessToken) + _ = credentialsStorage.setCredential(lease, oauthClient.refreshTokenSpec(), result.refreshToken) } return result.accessToken, result.err } @@ -354,7 +354,7 @@ func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() stri func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Lease) (string, error) { accessTokenSpec := oauthClient.accessTokenSpec() if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - if accessToken, err := getCredentialFast(accessTokenSpec); accessToken != "" { + if accessToken, err := credentialsStorage.getCredentialRelaxed(lease, accessTokenSpec); accessToken != "" { if err != nil { return "", fmt.Errorf("error retrieving access token from cache: %w", err) } @@ -370,7 +370,7 @@ func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Leas return "", err } if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { - _ = setCredentialWithLease(accessTokenSpec, token.AccessToken) + _ = credentialsStorage.setCredential(lease, accessTokenSpec, token.AccessToken) } return token.AccessToken, nil } @@ -393,7 +393,7 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error { return nil } refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User) - refreshToken, _ := getCredentialFast(refreshTokenSpec) + refreshToken, _ := credentialsStorage.getCredentialRelaxed(lease, refreshTokenSpec) if refreshToken == "" { logger.Debug("no refresh token in cache, full flow must be run") return nil @@ -422,7 +422,7 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error { if err != nil { return err } - _ = deleteCredentialWithLease(refreshTokenSpec) + _ = credentialsStorage.deleteCredential(lease, refreshTokenSpec) return errors.New(string(respBody)) } var tokenResponse tokenExchangeResponseBody @@ -430,9 +430,9 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error { return err } accessTokenSpec := oauthClient.accessTokenSpec() - _ = setCredentialWithLease(accessTokenSpec, tokenResponse.AccessToken) + _ = credentialsStorage.setCredential(lease, accessTokenSpec, tokenResponse.AccessToken) if tokenResponse.RefreshToken != "" { - _ = setCredentialWithLease(refreshTokenSpec, tokenResponse.RefreshToken) + _ = credentialsStorage.setCredential(lease, refreshTokenSpec, tokenResponse.RefreshToken) } return nil } diff --git a/secure_storage_manager.go b/secure_storage_manager.go index 164dc2e09..aa951a5d9 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -115,64 +115,12 @@ type secureStorageManager interface { acquireLease() (*Lease, error) setCredential(lease *Lease, tokenSpec *secureTokenSpec, value string) error getCredential(lease *Lease, tokenSpec *secureTokenSpec) (string, error) + getCredentialRelaxed(optionalLease *Lease, tokenSpec *secureTokenSpec) (string, error) deleteCredential(lease *Lease, tokenSpec *secureTokenSpec) error } var credentialsStorage = newSecureStorageManager() -// Helper fast-paths to minimize lease contention for common operations. -// These acquire a lease only if needed. - -// getCredentialFast attempts an in-memory read first (if available), -// otherwise acquires a lease to consult the persistent cache. -func getCredentialFast(tokenSpec *secureTokenSpec) (string, error) { - // Try in-memory if file-based storage - if fb, ok := credentialsStorage.(*fileBasedSecureStorageManager); ok { - key, err := tokenSpec.buildKey() - if err != nil { - return "", err - } - fb.memMu.RLock() - if v, ok := fb.mem[key]; ok { - fb.memMu.RUnlock() - return v, nil - } - fb.memMu.RUnlock() - } - - // Miss: acquire lease and consult backing store - lease, err := credentialsStorage.acquireLease() - if err != nil { - return "", err - } - if lease != nil { - defer lease.Release() - } - return credentialsStorage.getCredential(lease, tokenSpec) -} - -func setCredentialWithLease(tokenSpec *secureTokenSpec, value string) error { - lease, err := credentialsStorage.acquireLease() - if err != nil { - return err - } - if lease != nil { - defer lease.Release() - } - return credentialsStorage.setCredential(lease, tokenSpec, value) -} - -func deleteCredentialWithLease(tokenSpec *secureTokenSpec) error { - lease, err := credentialsStorage.acquireLease() - if err != nil { - return err - } - if lease != nil { - defer lease.Release() - } - return credentialsStorage.deleteCredential(lease, tokenSpec) -} - func newSecureStorageManager() secureStorageManager { var ssm secureStorageManager var err error @@ -365,7 +313,17 @@ func (ssm *fileBasedSecureStorageManager) setCredential(lease *Lease, tokenSpec ssm.memMu.Unlock() } - return ssm.withCacheFile(lease, func(cacheFile *os.File) error { + // Acquire lease if not provided + var actualLease = lease + if lease == nil { + actualLease, err = ssm.acquireLease() + if err != nil { + return err + } + defer actualLease.Release() + } + + return ssm.withCacheFile(actualLease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file: %v", err) @@ -457,6 +415,37 @@ func (ssm *fileBasedSecureStorageManager) unlockFile() { } } +// Attempts an in-memory read first, +// Otherwise consults the persisted cache with a lease +func (ssm *fileBasedSecureStorageManager) getCredentialRelaxed(optionalLease *Lease, tokenSpec *secureTokenSpec) (string, error) { + key, err := tokenSpec.buildKey() + if err != nil { + return "", err + } + + ssm.memMu.RLock() + if v, ok := ssm.mem[key]; ok { + ssm.memMu.RUnlock() + return v, nil + } + ssm.memMu.RUnlock() + + // Miss: consult backing store + var lease = optionalLease + if optionalLease == nil { + lease, err = ssm.acquireLease() + if err != nil { + return "", err + } + } + + if lease != nil { + defer lease.Release() + } + + return ssm.getCredential(lease, tokenSpec) +} + func (ssm *fileBasedSecureStorageManager) getCredential(lease *Lease, tokenSpec *secureTokenSpec) (string, error) { credentialsKey, err := tokenSpec.buildKey() if err != nil { @@ -471,8 +460,18 @@ func (ssm *fileBasedSecureStorageManager) getCredential(lease *Lease, tokenSpec } ssm.memMu.RUnlock() + // Acquire lease if not provided + var actualLease = lease + if lease == nil { + actualLease, err = ssm.acquireLease() + if err != nil { + return "", err + } + defer actualLease.Release() + } + ret := "" - err = ssm.withCacheFile(lease, func(cacheFile *os.File) error { + err = ssm.withCacheFile(actualLease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file. %v", err) @@ -586,7 +585,17 @@ func (ssm *fileBasedSecureStorageManager) deleteCredential(lease *Lease, tokenSp delete(ssm.mem, credentialsKey) ssm.memMu.Unlock() - return ssm.withCacheFile(lease, func(cacheFile *os.File) error { + // Acquire lease if not provided + var actualLease = lease + if lease == nil { + actualLease, err = ssm.acquireLease() + if err != nil { + return err + } + defer actualLease.Release() + } + + return ssm.withCacheFile(actualLease, func(cacheFile *os.File) error { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file. %v", err) @@ -758,6 +767,10 @@ func (ssm *noopSecureStorageManager) getCredential(_ *Lease, _ *secureTokenSpec) return "", nil // no-op implementation for secure storage manager } +func (ssm *noopSecureStorageManager) getCredentialRelaxed(_ *Lease, _ *secureTokenSpec) (string, error) { + return "", nil +} + func (ssm *noopSecureStorageManager) deleteCredential(_ *Lease, _ *secureTokenSpec) error { return nil } From 11195ef7cf46ab82aae2d3db0e52a3cbcc1e2511 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 25 Nov 2025 01:30:16 -0800 Subject: [PATCH 4/4] Optional lease --- auth.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/auth.go b/auth.go index 1c77dff9d..dcceae9f6 100644 --- a/auth.go +++ b/auth.go @@ -339,7 +339,7 @@ func getHeaders() map[string]string { // Used to authenticate the user with Snowflake. func authenticate( ctx context.Context, - lease *Lease, + optionalLease *Lease, sc *snowflakeConn, samlResponse []byte, proofKey []byte, @@ -403,7 +403,7 @@ func authenticate( sessionParameters[clientStoreTemporaryCredential] = true } bodyCreator := func() ([]byte, error) { - return createRequestBody(sc, lease, sessionParameters, clientEnvironment, proofKey, samlResponse) + return createRequestBody(sc, optionalLease, sessionParameters, clientEnvironment, proofKey, samlResponse) } params := &url.Values{} @@ -431,13 +431,13 @@ func authenticate( logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) if sessionParameters[clientRequestMfaToken] == true { - _ = credentialsStorage.deleteCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(optionalLease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User)) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { - _ = credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(optionalLease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User)) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() { - _ = credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) + _ = credentialsStorage.deleteCredential(optionalLease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) } code, err := strconv.Atoi(respd.Code) if err != nil { @@ -454,7 +454,7 @@ func authenticate( if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { token := respd.Data.MfaToken - _ = credentialsStorage.setCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) + _ = credentialsStorage.setCredential(optionalLease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token) } if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser { @@ -462,7 +462,7 @@ func authenticate( // XXX: for some reason, token is empty here some times and we // don't want to clear the cache, so let's skip it if it's empty if token != "" { - _ = credentialsStorage.setCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) + _ = credentialsStorage.setCredential(optionalLease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token) } } return &respd.Data, nil