Skip to content

Commit a835fc4

Browse files
committed
oauth2: move global auth style cache to be per-Config
In 80673b4 (https://go.dev/cl/157820) I added a never-shrinking package-global cache to remember which auto-detected auth style (HTTP headers vs POST) was supported by a certain OAuth2 server, keyed by its URL. Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and have one global OpenID configuration document for all of their customers which says ("we support all auth styles! you pick!") but then give each customer control of which style they specifically accept. This is bogus behavior on their part, but the oauth2 package's global caching per URL isn't helping. (It's also bad to have a package-global cache that can never be GC'ed) So, this change moves the cache to hang off the oauth *Configs instead. Unfortunately, it does so with some backwards compatiblity compromises (an atomic.Value hack), lest people are using old versions of Go still or copying a Config by value, both of which this package previously accidentally supported, even though they weren't tested. This change also means that anybody that's repeatedly making ephemeral oauth.Configs without an explicit auth style will be losing & reinitializing their cache on any auth style failures + fallbacks to the other style. I think that should be pretty rare. People seem to make an oauth2.Config once earlier and stash it away somewhere (often deep in a token fetcher or HTTP client/transport). Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b Signed-off-by: Brad Fitzpatrick <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675 TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]> Reviewed-by: Adrian Dewhurst <[email protected]> Reviewed-by: Michael Knyszek <[email protected]>
1 parent 2e4a4e2 commit a835fc4

File tree

7 files changed

+60
-39
lines changed

7 files changed

+60
-39
lines changed

clientcredentials/clientcredentials.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ type Config struct {
4747
// client ID & client secret sent. The zero value means to
4848
// auto-detect.
4949
AuthStyle oauth2.AuthStyle
50+
51+
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
52+
// the zero value (AuthStyleAutoDetect).
53+
authStyleCache internal.LazyAuthStyleCache
5054
}
5155

5256
// Token uses client credentials to retrieve a token.
@@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
103107
v[k] = p
104108
}
105109

106-
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
110+
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
107111
if err != nil {
108112
if rErr, ok := err.(*internal.RetrieveError); ok {
109113
return nil, (*oauth2.RetrieveError)(rErr)

clientcredentials/clientcredentials_test.go

-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
"net/http/httptest"
1313
"net/url"
1414
"testing"
15-
16-
"golang.org/x/oauth2/internal"
1715
)
1816

1917
func newConf(serverURL string) *Config {
@@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
114112
}
115113

116114
func TestTokenRefreshRequest(t *testing.T) {
117-
internal.ResetAuthCache()
118115
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
119116
if r.URL.String() == "/somethingelse" {
120117
return

internal/token.go

+45-25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"strconv"
1919
"strings"
2020
"sync"
21+
"sync/atomic"
2122
"time"
2223
)
2324

@@ -115,41 +116,60 @@ const (
115116
AuthStyleInHeader AuthStyle = 2
116117
)
117118

118-
// authStyleCache is the set of tokenURLs we've successfully used via
119+
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
120+
// have a lazily-initialized AuthStyleCache.
121+
//
122+
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
123+
// both would ideally just embed an unexported AuthStyleCache but because both
124+
// were historically allowed to be copied by value we can't retroactively add an
125+
// uncopyable Mutex to them.
126+
//
127+
// We could use an atomic.Pointer, but that was added recently enough (in Go
128+
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
129+
// still pass. By using an atomic.Value, it supports both Go 1.17 and
130+
// copying by value, even if that's not ideal.
131+
type LazyAuthStyleCache struct {
132+
v atomic.Value // of *AuthStyleCache
133+
}
134+
135+
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
136+
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
137+
return c
138+
}
139+
c := new(AuthStyleCache)
140+
if !lc.v.CompareAndSwap(nil, c) {
141+
c = lc.v.Load().(*AuthStyleCache)
142+
}
143+
return c
144+
}
145+
146+
// AuthStyleCache is the set of tokenURLs we've successfully used via
119147
// RetrieveToken and which style auth we ended up using.
120148
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
121149
// the set of OAuth2 servers a program contacts over time is fixed and
122150
// small.
123-
var authStyleCache struct {
124-
sync.Mutex
125-
m map[string]AuthStyle // keyed by tokenURL
126-
}
127-
128-
// ResetAuthCache resets the global authentication style cache used
129-
// for AuthStyleUnknown token requests.
130-
func ResetAuthCache() {
131-
authStyleCache.Lock()
132-
defer authStyleCache.Unlock()
133-
authStyleCache.m = nil
151+
type AuthStyleCache struct {
152+
mu sync.Mutex
153+
m map[string]AuthStyle // keyed by tokenURL
134154
}
135155

136156
// lookupAuthStyle reports which auth style we last used with tokenURL
137157
// when calling RetrieveToken and whether we have ever done so.
138-
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
139-
authStyleCache.Lock()
140-
defer authStyleCache.Unlock()
141-
style, ok = authStyleCache.m[tokenURL]
158+
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
159+
c.mu.Lock()
160+
defer c.mu.Unlock()
161+
style, ok = c.m[tokenURL]
142162
return
143163
}
144164

145165
// setAuthStyle adds an entry to authStyleCache, documented above.
146-
func setAuthStyle(tokenURL string, v AuthStyle) {
147-
authStyleCache.Lock()
148-
defer authStyleCache.Unlock()
149-
if authStyleCache.m == nil {
150-
authStyleCache.m = make(map[string]AuthStyle)
166+
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
167+
c.mu.Lock()
168+
defer c.mu.Unlock()
169+
if c.m == nil {
170+
c.m = make(map[string]AuthStyle)
151171
}
152-
authStyleCache.m[tokenURL] = v
172+
c.m[tokenURL] = v
153173
}
154174

155175
// newTokenRequest returns a new *http.Request to retrieve a new token
@@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
189209
return v2
190210
}
191211

192-
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
212+
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
193213
needsAuthStyleProbe := authStyle == 0
194214
if needsAuthStyleProbe {
195-
if style, ok := lookupAuthStyle(tokenURL); ok {
215+
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
196216
authStyle = style
197217
needsAuthStyleProbe = false
198218
} else {
@@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
222242
token, err = doTokenRoundTrip(ctx, req)
223243
}
224244
if needsAuthStyleProbe && err == nil {
225-
setAuthStyle(tokenURL, authStyle)
245+
styleCache.setAuthStyle(tokenURL, authStyle)
226246
}
227247
// Don't overwrite `RefreshToken` with an empty value
228248
// if this was a token refreshing request.

internal/token_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
)
1717

1818
func TestRetrieveToken_InParams(t *testing.T) {
19-
ResetAuthCache()
19+
styleCache := new(AuthStyleCache)
2020
const clientID = "client-id"
2121
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2222
if got, want := r.FormValue("client_id"), clientID; got != want {
@@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
2929
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
3030
}))
3131
defer ts.Close()
32-
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
32+
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
3333
if err != nil {
3434
t.Errorf("RetrieveToken = %v; want no error", err)
3535
}
3636
}
3737

3838
func TestRetrieveTokenWithContexts(t *testing.T) {
39-
ResetAuthCache()
39+
styleCache := new(AuthStyleCache)
4040
const clientID = "client-id"
4141

4242
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
4545
}))
4646
defer ts.Close()
4747

48-
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
48+
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
4949
if err != nil {
5050
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
5151
}
@@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
5858

5959
ctx, cancel := context.WithCancel(context.Background())
6060
cancel()
61-
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
61+
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
6262
close(retrieved)
6363
if err == nil {
6464
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")

oauth2.go

+4
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ type Config struct {
5858

5959
// Scope specifies optional requested permissions.
6060
Scopes []string
61+
62+
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
63+
// the zero value (AuthStyleAutoDetect).
64+
authStyleCache internal.LazyAuthStyleCache
6165
}
6266

6367
// A TokenSource is anything that can return a token.

oauth2_test.go

-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ import (
1515
"net/url"
1616
"testing"
1717
"time"
18-
19-
"golang.org/x/oauth2/internal"
2018
)
2119

2220
type mockTransport struct {
@@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
355353
}
356354

357355
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
358-
internal.ResetAuthCache()
359356
tr := &mockTransport{
360357
rt: func(r *http.Request) (w *http.Response, err error) {
361358
headerAuth := r.Header.Get("Authorization")
@@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
427424
}
428425

429426
func TestTokenRefreshRequest(t *testing.T) {
430-
internal.ResetAuthCache()
431427
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
432428
if r.URL.String() == "/somethingelse" {
433429
return

token.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
164164
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
165165
// with an error..
166166
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
167-
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
167+
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
168168
if err != nil {
169169
if rErr, ok := err.(*internal.RetrieveError); ok {
170170
return nil, (*RetrieveError)(rErr)

0 commit comments

Comments
 (0)