diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 7a0b9ed10..b62dca277 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -74,20 +74,25 @@ func (c *Config) Client(ctx context.Context) *http.Client { // Most users will use Config.Client instead. func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { source := &tokenSource{ - ctx: ctx, - conf: c, + httpclient: internal.ContextClient(ctx), + conf: c, } return oauth2.ReuseTokenSource(nil, source) } type tokenSource struct { - ctx context.Context - conf *Config + httpclient *http.Client + conf *Config } // Token refreshes the token by using a new client credentials request. // tokens received this way do not include a refresh token func (c *tokenSource) Token() (*oauth2.Token, error) { + return c.TokenContext(context.Background()) +} + +// TokenContext implements oauth2.tokenSourceContext. +func (c *tokenSource) TokenContext(ctx context.Context) (*oauth2.Token, error) { v := url.Values{ "grant_type": {"client_credentials"}, } @@ -103,7 +108,8 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { v[k] = p } - tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) + ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpclient) + tk, err := internal.RetrieveToken(ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*oauth2.RetrieveError)(rErr) diff --git a/oauth2.go b/oauth2.go index 291df5c83..c72124097 100644 --- a/oauth2.go +++ b/oauth2.go @@ -67,6 +67,10 @@ type TokenSource interface { Token() (*Token, error) } +type tokenSourceContext interface { + TokenContext(context.Context) (*Token, error) +} + // Endpoint represents an OAuth 2.0 provider's authorization and token // endpoint URLs. type Endpoint struct { @@ -296,15 +300,31 @@ type reuseTokenSource struct { // refresh the current token (using r.Context for HTTP client // information) and return the new one. func (s *reuseTokenSource) Token() (*Token, error) { + return s.TokenContext(context.Background()) +} + +// TokenContext implements tokenSourceContext. +func (s *reuseTokenSource) TokenContext(ctx context.Context) (*Token, error) { s.mu.Lock() defer s.mu.Unlock() if s.t.Valid() { return s.t, nil } - t, err := s.new.Token() + + var ( + t *Token + err error + ) + tsctx, ok := s.new.(tokenSourceContext) + if ok { + t, err = tsctx.TokenContext(ctx) + } else { + t, err = s.new.Token() + } if err != nil { return nil, err } + s.t = t return t, nil } diff --git a/transport.go b/transport.go index aa0d34f1e..f20525946 100644 --- a/transport.go +++ b/transport.go @@ -45,7 +45,17 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { if t.Source == nil { return nil, errors.New("oauth2: Transport's Source is nil") } - token, err := t.Source.Token() + + var ( + token *Token + err error + ) + tsctx, ok := t.Source.(tokenSourceContext) + if ok { + token, err = tsctx.TokenContext(req.Context()) + } else { + token, err = t.Source.Token() + } if err != nil { return nil, err }