diff --git a/transport.go b/transport.go index 8bbebbac9..6064877a9 100644 --- a/transport.go +++ b/transport.go @@ -48,6 +48,9 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } req2 := req.Clone(req.Context()) + if req2.Header == nil { + req2.Header = make(http.Header) + } token.SetAuthHeader(req2) // req.Body is assumed to be closed by the base RoundTripper. diff --git a/transport_test.go b/transport_test.go index a8e6ea236..7733293c7 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "testing" "time" ) @@ -154,3 +155,51 @@ func TestExpiredWithExpiry(t *testing.T) { func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { return httptest.NewServer(http.HandlerFunc(handler)) } + +// TestTransportWithNilHeader tests that the Transport.RoundTrip method +// correctly handles requests with nil Headers. +func TestTransportWithNilHeader(t *testing.T) { + // Create a mock token source that returns a fixed token + tokenSource := StaticTokenSource(&Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }) + + // Create a mock http server to verify the request + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that the Authorization header was correctly set + authHeader := r.Header.Get("Authorization") + expectedHeader := "Bearer test-access-token" + if authHeader != expectedHeader { + t.Errorf("expected authorization header %q, got %q", expectedHeader, authHeader) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create Transport with our token source + transport := &Transport{ + Source: tokenSource, + Base: http.DefaultTransport, + } + + // Create a request with nil Header + reqURL, _ := url.Parse(server.URL) + req := &http.Request{ + Method: "GET", + URL: reqURL, + // Header is intentionally nil + } + + // Make the request using our Transport + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("roundTrip failed with nil Header: %v", err) + } + defer resp.Body.Close() + + // Verify response status + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } +}