Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 56 additions & 22 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func getHeaders() map[string]string {
// Used to authenticate the user with Snowflake.
func authenticate(
ctx context.Context,
lease *Lease,
optionalLease *Lease,
sc *snowflakeConn,
samlResponse []byte,
proofKey []byte,
Expand Down Expand Up @@ -403,7 +403,7 @@ func authenticate(
sessionParameters[clientStoreTemporaryCredential] = true
}
bodyCreator := func() ([]byte, error) {
return createRequestBody(sc, lease, sessionParameters, clientEnvironment, proofKey, samlResponse)
return createRequestBody(sc, optionalLease, sessionParameters, clientEnvironment, proofKey, samlResponse)
}

params := &url.Values{}
Expand Down Expand Up @@ -431,13 +431,13 @@ func authenticate(
logger.WithContext(ctx).Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == true {
credentialsStorage.deleteCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
_ = credentialsStorage.deleteCredential(optionalLease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
}
if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser {
credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
_ = credentialsStorage.deleteCredential(optionalLease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
}
if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator.isOauthNativeFlow() {
credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
_ = credentialsStorage.deleteCredential(optionalLease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
}
code, err := strconv.Atoi(respd.Code)
if err != nil {
Expand All @@ -454,15 +454,15 @@ func authenticate(

if sessionParameters[clientRequestMfaToken] == true && sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
token := respd.Data.MfaToken
credentialsStorage.setCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
_ = credentialsStorage.setCredential(optionalLease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
}

if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser {
token := respd.Data.IDToken
// XXX: for some reason, token is empty here some times and we
// don't want to clear the cache, so let's skip it if it's empty
if token != "" {
credentialsStorage.setCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
_ = credentialsStorage.setCredential(optionalLease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
}
}
return &respd.Data, nil
Expand Down Expand Up @@ -634,7 +634,7 @@ func authenticateByAuthorizationCode(sc *snowflakeConn, lease *Lease) (string, e
valueAwaiter := valueAwaitHolder.get(lockKey)
defer valueAwaiter.resumeOne()
token, err := awaitValue(valueAwaiter, func() (string, error) {
return credentialsStorage.getCredential(lease, newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User))
return credentialsStorage.getCredentialRelaxed(lease, newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User))
}, func(s string, err error) bool {
return s != ""
}, func() string {
Expand Down Expand Up @@ -734,6 +734,13 @@ func extBrowserBackoffKey(host, user string) string {

// Authenticate with sc.cfg
func authenticateWithConfig(sc *snowflakeConn) error {
return doAuthenticateWithConfig(sc, nil)
}

// Perform authentication and consult caches
// First, attempt a relaxed read without a lease
// If this is unsuccessful, recurse once with a lease
func doAuthenticateWithConfig(sc *snowflakeConn, optionalLease *Lease) error {
var authData *authResponseMain
var samlResponse []byte
var proofKey []byte
Expand All @@ -742,11 +749,16 @@ func authenticateWithConfig(sc *snowflakeConn) error {
mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User)
idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User)

lease, err := credentialsStorage.acquireLease()
if err != nil {
return err
}
defer lease.Release()
// Lease for coordination - may be provided or acquired locally
var lease *Lease = optionalLease
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can keep mutating optionalLease. You have a variable named lease that is just as optional as optionalLease

leaseAcquiredLocally := false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lease being nil indicates it's not acquired, right? With the bool we now have 4 different states to handle.


// Release lease at the end if we acquired it locally
defer func() {
if leaseAcquiredLocally && lease != nil {
lease.Release()
}
}()

key := extBrowserBackoffKey(sc.cfg.Host, sc.cfg.User)

Expand All @@ -762,7 +774,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.IDToken, _ = awaitValue(
valueAwaiter,
func() (string, error) {
tok, _ := credentialsStorage.getCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
tok, _ := credentialsStorage.getCredentialRelaxed(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
return tok, nil
},
func(s string, err error) bool {
Expand All @@ -774,7 +786,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
)

} else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
tok, _ := credentialsStorage.getCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
tok, _ := credentialsStorage.getCredentialRelaxed(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
sc.cfg.IDToken = tok
}
// Disable console login by default
Expand All @@ -793,7 +805,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.MfaToken, _ = awaitValue(
valueAwaiter,
func() (string, error) {
tok, err := credentialsStorage.getCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
tok, err := credentialsStorage.getCredentialRelaxed(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
if err != nil {
logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err)
}
Expand All @@ -805,7 +817,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
},
)
} else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
tok, err := credentialsStorage.getCredential(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
tok, err := credentialsStorage.getCredentialRelaxed(lease, newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
if err != nil {
logger.WithContext(sc.ctx).Warnf("failed to get MFA token from credential storage: %v", err)
}
Expand All @@ -829,6 +841,18 @@ func authenticateWithConfig(sc *snowflakeConn) error {
lastFail.Delete(key)
}

// External browser path requires a lease for renewals
// When using parallel login coordination, we need the lease but cannot recurse
// due to the valueAwaiter
if lease == nil {
lease, err = credentialsStorage.acquireLease()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lease, err = credentialsStorage.acquireLease()
optionalLease, err = credentialsStorage.acquireLease()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You acquire a lease here and proceed with the authentication process, but since your decision to proceed with auth comes from a relaxed read, you can't be sure the credentials cache is still empty. Between the relaxed read and this lease acquisition, some other process might have populated the credentials cache.

That's why the double-checked locking pattern requires two checks: one relaxed check that is fast and then another check after the lease/lock acquisition.

if err != nil {
sc.cleanup()
return err
}
leaseAcquiredLocally = true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optionalLease being non-nil speaks for itself

}

samlResponse, proofKey, err = authenticateByExternalBrowser(
sc.ctx,
lease,
Expand All @@ -845,6 +869,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
}
}
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

authData, err = authenticate(
sc.ctx,
lease,
Expand All @@ -862,7 +887,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
switch {
// Case 1: cached ID token failed -> clear + try one interactive refresh
case sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.IDToken != "":
credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
_ = credentialsStorage.deleteCredential(lease, newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't delete in the filesystem if the lease is nil. As an invariant, no mutation should derive from a relaxed read of the cached credentials.

sc.cfg.IDToken = ""

// NOTE on tabstorms:
Expand All @@ -874,6 +899,15 @@ func authenticateWithConfig(sc *snowflakeConn) error {
// considered low probability and low impact

// reenter interactive flow
if lease == nil {
lease, err = credentialsStorage.acquireLease()
if err != nil {
// let SAML-phase failure remain retryable; do not set backoff
sc.cleanup()
return err
}
leaseAcquiredLocally = true
}
samlResponse, proofKey, err = authenticateByExternalBrowser(
sc.ctx, lease, sc.rest, sc.cfg.Authenticator.String(),
sc.cfg.Application, sc.cfg.Account, sc.cfg.User,
Expand All @@ -891,9 +925,9 @@ func authenticateWithConfig(sc *snowflakeConn) error {
}
fallthrough

// Case 2: still failing, but could be an OAuth refreshable error
// Case 2: still failing, but could be an OAuth refreshable error
case errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)):
credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
_ = credentialsStorage.deleteCredential(lease, newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))

if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode {
doRefreshTokenWithLease(sc, lease)
Expand Down Expand Up @@ -941,7 +975,7 @@ func doRefreshTokenWithLock(sc *snowflakeConn, lease *Lease) {
if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) {
if err = oauthClient.refreshToken(lease); err != nil {
logger.Warnf("cannot refresh token. %v", err)
credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
_ = credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
return "", err
}
return "", nil
Expand All @@ -957,7 +991,7 @@ func doRefreshTokenWithLease(sc *snowflakeConn, lease *Lease) {
} else {
if rfErr := oauthClient.refreshToken(lease); rfErr != nil {
logger.Warnf("cannot refresh token. %v", rfErr)
credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
_ = credentialsStorage.deleteCredential(lease, newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
}
}

Expand Down
20 changes: 10 additions & 10 deletions auth_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ type oauthBrowserResult struct {
func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Lease) (string, error) {
accessTokenSpec := oauthClient.accessTokenSpec()
if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
if accessToken, err := credentialsStorage.getCredential(lease, accessTokenSpec); accessToken != "" {
if accessToken, err := credentialsStorage.getCredentialRelaxed(lease, accessTokenSpec); accessToken != "" {
logger.Debugf("Access token retrieved from cache: %v", err)
return accessToken, nil
}
if refreshToken, _ := credentialsStorage.getCredential(lease, oauthClient.refreshTokenSpec()); refreshToken != "" {
if refreshToken, _ := credentialsStorage.getCredentialRelaxed(lease, oauthClient.refreshTokenSpec()); refreshToken != "" {
return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent}
}
}
Expand All @@ -123,8 +123,8 @@ func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode(lease *Leas
case result := <-resultChan:
if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
logger.Debug("saving oauth access token in cache")
credentialsStorage.setCredential(lease, oauthClient.accessTokenSpec(), result.accessToken)
credentialsStorage.setCredential(lease, oauthClient.refreshTokenSpec(), result.refreshToken)
_ = credentialsStorage.setCredential(lease, oauthClient.accessTokenSpec(), result.accessToken)
_ = credentialsStorage.setCredential(lease, oauthClient.refreshTokenSpec(), result.refreshToken)
}
return result.accessToken, result.err
}
Expand Down Expand Up @@ -354,7 +354,7 @@ func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() stri
func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Lease) (string, error) {
accessTokenSpec := oauthClient.accessTokenSpec()
if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
if accessToken, err := credentialsStorage.getCredential(lease, accessTokenSpec); accessToken != "" {
if accessToken, err := credentialsStorage.getCredentialRelaxed(lease, accessTokenSpec); accessToken != "" {
if err != nil {
return "", fmt.Errorf("error retrieving access token from cache: %w", err)
}
Expand All @@ -370,7 +370,7 @@ func (oauthClient *oauthClient) authenticateByOAuthClientCredentials(lease *Leas
return "", err
}
if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
credentialsStorage.setCredential(lease, accessTokenSpec, token.AccessToken)
_ = credentialsStorage.setCredential(lease, accessTokenSpec, token.AccessToken)
}
return token.AccessToken, nil
}
Expand All @@ -393,7 +393,7 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error {
return nil
}
refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User)
refreshToken, _ := credentialsStorage.getCredential(lease, refreshTokenSpec)
refreshToken, _ := credentialsStorage.getCredentialRelaxed(lease, refreshTokenSpec)
if refreshToken == "" {
logger.Debug("no refresh token in cache, full flow must be run")
return nil
Expand Down Expand Up @@ -422,17 +422,17 @@ func (oauthClient *oauthClient) refreshToken(lease *Lease) error {
if err != nil {
return err
}
credentialsStorage.deleteCredential(lease, refreshTokenSpec)
_ = credentialsStorage.deleteCredential(lease, refreshTokenSpec)
return errors.New(string(respBody))
}
var tokenResponse tokenExchangeResponseBody
if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
return err
}
accessTokenSpec := oauthClient.accessTokenSpec()
credentialsStorage.setCredential(lease, accessTokenSpec, tokenResponse.AccessToken)
_ = credentialsStorage.setCredential(lease, accessTokenSpec, tokenResponse.AccessToken)
if tokenResponse.RefreshToken != "" {
credentialsStorage.setCredential(lease, refreshTokenSpec, tokenResponse.RefreshToken)
_ = credentialsStorage.setCredential(lease, refreshTokenSpec, tokenResponse.RefreshToken)
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bin
JSON: data.Data.RowSet,
RowSetBase64: data.Data.RowSetBase64,
},
queryID: data.Data.QueryID,
queryID: data.Data.QueryID,
resultIDs: resultIDs,
}
// if multistatement is used, we need to set the first result set to actual result set, not the aggregated response
Expand Down
Loading
Loading