diff --git a/go/http/echo/middleware.go b/go/http/echo/middleware.go index 83bb5a66ec..89d8b09d75 100644 --- a/go/http/echo/middleware.go +++ b/go/http/echo/middleware.go @@ -53,11 +53,15 @@ func (a *EchoAdapter) GetPath() string { func (a *EchoAdapter) GetURL() string { req := a.ctx.Request() scheme := "http" - if req.TLS != nil { + if xForwardedProto := req.Header.Get("X-Forwarded-Proto"); xForwardedProto != "" { + scheme = xForwardedProto + } else if req.TLS != nil { scheme = "https" } host := req.Host - if host == "" { + if xForwardedHost := req.Header.Get("X-Forwarded-Host"); xForwardedHost != "" { + host = xForwardedHost + } else if host == "" { host = req.Header.Get("Host") } return fmt.Sprintf("%s://%s%s", scheme, host, req.RequestURI) diff --git a/go/http/echo/middleware_test.go b/go/http/echo/middleware_test.go index 674c5c03b7..d6f3db3cc5 100644 --- a/go/http/echo/middleware_test.go +++ b/go/http/echo/middleware_test.go @@ -212,6 +212,11 @@ func TestEchoAdapter_GetURL(t *testing.T) { target: "/api/test?id=1&foo=bar", expected: "http://example.com/api/test?id=1&foo=bar", }, + { + name: "x forwarded for", + target: "/api/test", + expected: "https://example.com/api/test", + }, } for _, tt := range tests { @@ -226,6 +231,10 @@ func TestEchoAdapter_GetURL(t *testing.T) { req := httptest.NewRequest("GET", tt.target, nil) req.Host = "example.com" + if tt.name == "x forwarded for" { + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + } w := httptest.NewRecorder() e.ServeHTTP(w, req) diff --git a/go/http/gin/middleware.go b/go/http/gin/middleware.go index 9e6ae00f01..9df9e469ff 100644 --- a/go/http/gin/middleware.go +++ b/go/http/gin/middleware.go @@ -52,11 +52,15 @@ func (a *GinAdapter) GetPath() string { // GetURL gets the full request URL func (a *GinAdapter) GetURL() string { scheme := "http" - if a.ctx.Request.TLS != nil { + if xForwardedProto := a.ctx.GetHeader("X-Forwarded-Proto"); xForwardedProto != "" { + scheme = xForwardedProto + } else if a.ctx.Request.TLS != nil { scheme = "https" } host := a.ctx.Request.Host - if host == "" { + if xForwardedHost := a.ctx.GetHeader("X-Forwarded-Host"); xForwardedHost != "" { + host = xForwardedHost + } else if host == "" { host = a.ctx.GetHeader("Host") } return fmt.Sprintf("%s://%s%s", scheme, host, a.ctx.Request.URL.RequestURI()) diff --git a/go/http/gin/middleware_test.go b/go/http/gin/middleware_test.go index ae66119737..5595e5bd39 100644 --- a/go/http/gin/middleware_test.go +++ b/go/http/gin/middleware_test.go @@ -220,6 +220,11 @@ func TestGinAdapter_GetURL(t *testing.T) { target: "/api/test?id=1&foo=bar", expected: "http://example.com/api/test?id=1&foo=bar", }, + { + name: "x forwarded for", + target: "/api/test", + expected: "https://example.com/api/test", + }, } for _, tt := range tests { @@ -233,6 +238,10 @@ func TestGinAdapter_GetURL(t *testing.T) { req := httptest.NewRequest("GET", tt.target, nil) req.Host = "example.com" + if tt.name == "x forwarded for" { + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + } w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/go/http/nethttp/adapter.go b/go/http/nethttp/adapter.go index 8aa0888571..8ad3ab3b2d 100644 --- a/go/http/nethttp/adapter.go +++ b/go/http/nethttp/adapter.go @@ -33,11 +33,15 @@ func (a *NetHTTPAdapter) GetPath() string { // GetURL gets the full request URL. func (a *NetHTTPAdapter) GetURL() string { scheme := "http" - if a.r.TLS != nil { + if xForwardedProto := a.r.Header.Get("X-Forwarded-Proto"); xForwardedProto != "" { + scheme = xForwardedProto + } else if a.r.TLS != nil { scheme = "https" } host := a.r.Host - if host == "" { + if xForwardedHost := a.r.Header.Get("X-Forwarded-Host"); xForwardedHost != "" { + host = xForwardedHost + } else if host == "" { host = a.r.Header.Get("Host") } return fmt.Sprintf("%s://%s%s", scheme, host, a.r.URL.RequestURI()) diff --git a/go/http/nethttp/middleware_test.go b/go/http/nethttp/middleware_test.go index 9c4d98666f..eadc28e6f3 100644 --- a/go/http/nethttp/middleware_test.go +++ b/go/http/nethttp/middleware_test.go @@ -191,6 +191,11 @@ func TestNetHTTPAdapter_GetURL(t *testing.T) { target: "/api/test?id=1&foo=bar", expected: "http://example.com/api/test?id=1&foo=bar", }, + { + name: "x forwarded for", + target: "/api/test", + expected: "https://example.com/api/test", + }, } for _, tt := range tests { @@ -199,6 +204,11 @@ func TestNetHTTPAdapter_GetURL(t *testing.T) { req.Host = "example.com" adapter := NewNetHTTPAdapter(req) + if tt.name == "x forwarded for" { + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + } + if adapter.GetURL() != tt.expected { t.Errorf("Expected URL '%s', got '%s'", tt.expected, adapter.GetURL()) }