From 6b30711c387499854cf24786c64ea3f2f2a4785a Mon Sep 17 00:00:00 2001
From: Alexander Lyashko <aurx@mail.ru>
Date: Fri, 19 Apr 2019 10:18:20 +0300
Subject: [PATCH] errors: return all token fetch related errors as structured

Right now all non-HTTP errors are returned as plain error with combined
text messages into single string, e.g.:
oauth2: cannot fetch token: Post https://oauth2.googleapis.com/token: read tcp 10.0.6.7:46650->172.217.212.95:443: read: connection reset by peer

So caller doesn't have a chance to check easily if underlying error was
retriable (e.g. net.IsTemporary() or DNS error) or not.

This patch wraps ALL errors as RetrieveError with additional Err field
in case it was non-http error.
---
 internal/token.go | 13 +++++++++----
 jira/jira.go      |  8 ++++----
 jwt/jwt.go        |  6 +++---
 token.go          | 11 +++++++++--
 4 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/internal/token.go b/internal/token.go
index 355c38696..bc5dd4c04 100644
--- a/internal/token.go
+++ b/internal/token.go
@@ -231,12 +231,12 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
 func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
 	r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
 	if err != nil {
-		return nil, err
+		return nil, &RetrieveError{Err: err}
 	}
 	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
 	r.Body.Close()
 	if err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &RetrieveError{Err: err}
 	}
 	if code := r.StatusCode; code < 200 || code > 299 {
 		return nil, &RetrieveError{
@@ -251,7 +251,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
 	case "application/x-www-form-urlencoded", "text/plain":
 		vals, err := url.ParseQuery(string(body))
 		if err != nil {
-			return nil, err
+			return nil, &RetrieveError{Err: err}
 		}
 		token = &Token{
 			AccessToken:  vals.Get("access_token"),
@@ -285,10 +285,15 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
 }
 
 type RetrieveError struct {
+	Err      error
 	Response *http.Response
 	Body     []byte
 }
 
 func (r *RetrieveError) Error() string {
-	return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+	if r.Err != nil {
+		return fmt.Sprintf("oauth2: cannot fetch token: %v", r.Err)
+	} else {
+		return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+	}
 }
diff --git a/jira/jira.go b/jira/jira.go
index 34415607c..ee047146e 100644
--- a/jira/jira.go
+++ b/jira/jira.go
@@ -111,15 +111,15 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
 	hc := oauth2.NewClient(js.ctx, nil)
 	resp, err := hc.PostForm(js.conf.Endpoint.TokenURL, v)
 	if err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	defer resp.Body.Close()
 	body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
 	if err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	if c := resp.StatusCode; c < 200 || c > 299 {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
+		return nil, &oauth2.RetrieveError{Response: resp, Body: body}
 	}
 
 	// tokenRes is the JSON response body.
@@ -129,7 +129,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
 		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
 	}
 	if err := json.Unmarshal(body, &tokenRes); err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	token := &oauth2.Token{
 		AccessToken: tokenRes.AccessToken,
diff --git a/jwt/jwt.go b/jwt/jwt.go
index 99f3e0a32..f24f48b07 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -124,12 +124,12 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
 	v.Set("assertion", payload)
 	resp, err := hc.PostForm(js.conf.TokenURL, v)
 	if err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	defer resp.Body.Close()
 	body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
 	if err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	if c := resp.StatusCode; c < 200 || c > 299 {
 		return nil, &oauth2.RetrieveError{
@@ -145,7 +145,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
 		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
 	}
 	if err := json.Unmarshal(body, &tokenRes); err != nil {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+		return nil, &oauth2.RetrieveError{Err: err}
 	}
 	token := &oauth2.Token{
 		AccessToken: tokenRes.AccessToken,
diff --git a/token.go b/token.go
index 822720341..3ba7367e9 100644
--- a/token.go
+++ b/token.go
@@ -165,8 +165,11 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
 }
 
 // RetrieveError is the error returned when the token endpoint returns a
-// non-2XX HTTP status code.
+// non-2XX HTTP status code or fails with non-HTTP error
 type RetrieveError struct {
+	// Err is the (net) error happened during HTTP request when response not received
+	Err error
+
 	Response *http.Response
 	// Body is the body that was consumed by reading Response.Body.
 	// It may be truncated.
@@ -174,5 +177,9 @@ type RetrieveError struct {
 }
 
 func (r *RetrieveError) Error() string {
-	return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+	if r.Err != nil {
+		return fmt.Sprintf("oauth2: cannot fetch token: %v", r.Err)
+	} else {
+		return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
+	}
 }