diff --git a/registry/tokentransport.go b/registry/tokentransport.go index 4e858b04..5d8cb614 100644 --- a/registry/tokentransport.go +++ b/registry/tokentransport.go @@ -1,8 +1,11 @@ package registry import ( + "bytes" "encoding/json" "fmt" + "io" + "io/ioutil" "net/http" "net/url" ) @@ -13,16 +16,47 @@ type TokenTransport struct { Password string } -func (t *TokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := t.Transport.RoundTrip(req) +func (t *TokenTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + + err = t.complementRequestGetBodyFunction(req) + if err != nil { + return + } + + resp, err = t.Transport.RoundTrip(req) if err != nil { - return resp, err + return } + if authService := isTokenDemand(resp); authService != nil { - resp.Body.Close() + _ = resp.Body.Close() + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + return + } + } resp, err = t.authAndRetry(authService, req) } - return resp, err + return +} + +func (t *TokenTransport) complementRequestGetBodyFunction(req *http.Request) (err error) { + if req.GetBody != nil || req.Body == nil { + return + } + + var snapshot []byte + snapshot, err = ioutil.ReadAll(req.Body) + if err != nil { + return + } + + req.Body = ioutil.NopCloser(bytes.NewReader(snapshot)) + req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(snapshot)), nil + } + return } type authToken struct {