From 26771015e689a18a2935cdf08c2b14fc50c6c6d3 Mon Sep 17 00:00:00 2001 From: Anders Wallin Date: Tue, 14 May 2024 10:55:04 +0700 Subject: [PATCH] Added CheckRedirect support to Client. --- client.go | 3 ++- client_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 79dc931..94fb900 100644 --- a/client.go +++ b/client.go @@ -769,6 +769,7 @@ func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) { // shims in a *retryablehttp.Client for added retries. func (c *Client) StandardClient() *http.Client { return &http.Client{ - Transport: &RoundTripper{Client: c}, + Transport: &RoundTripper{Client: c}, + CheckRedirect: c.HTTPClient.CheckRedirect, } } diff --git a/client_test.go b/client_test.go index b8a3d9d..f91bf79 100644 --- a/client_test.go +++ b/client_test.go @@ -648,6 +648,68 @@ func TestClient_CheckRetryStop(t *testing.T) { } } +func TestClient_CheckRedirects(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Location", "/new/path") + + switch r.URL.Path { + case "/301": + w.WriteHeader(301) + case "/302": + w.WriteHeader(302) + default: + w.WriteHeader(500) + } + })) + defer ts.Close() + + client := NewClient() + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + stdClient := client.StandardClient() + + tests := []int{301, 302} + + // Check that we get 301 and 302 responses. + for _, test := range tests { + resp, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, test)) + if err != nil { + t.Fatalf("unexpected error testing check redirect. %s", err.Error()) + } + if resp.StatusCode != test { + t.Fatalf("expected status code %d but got %d", test, resp.StatusCode) + } + + // Check with standard client as well. + resp, err = stdClient.Get(fmt.Sprintf("%s/%d", ts.URL, test)) + if err != nil { + t.Fatalf("unexpected error testing check redirect. %s", err.Error()) + } + if resp.StatusCode != test { + t.Fatalf("expected status code %d but got %d", test, resp.StatusCode) + } + } + + // Check that we get errors when using default check redirect policy. + client = NewClient() + client.RetryMax = 0 + stdClient = client.StandardClient() + + for _, test := range tests { + _, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, test)) + if err == nil { + t.Fatalf("expected none nil error when testing default redirect behavior") + } + + // Check with standard client as well. + _, err = stdClient.Get(fmt.Sprintf("%s/%d", ts.URL, test)) + if err == nil { + t.Fatalf("expected none nil error when testing default redirect behavior") + } + } +} + func TestClient_Head(t *testing.T) { // Mock server which always responds 200. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {