Skip to content

Commit 7947875

Browse files
authored
Adds missing early returns (#35)
Adds missing early returns on unexpected measurements headers along a regression test.
1 parent ca38dda commit 7947875

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

proxy/proxy.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
8484
// Note: the reverse proxy adds X-Forwarded-For header!
8585
if r.Header.Get(MeasurementHeader) != "" {
8686
http.Error(w, "unexpected measurement header passed", http.StatusForbidden)
87+
return
8788
}
8889
if r.Header.Get(AttestationTypeHeader) != "" {
8990
http.Error(w, "unexpected attestation type header passed", http.StatusForbidden)
91+
return
9092
}
9193

9294
if r.TLS != nil {

proxy/proxy_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package proxy
2+
3+
import (
4+
"io"
5+
"log/slog"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/flashbots/cvm-reverse-proxy/common"
11+
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func getTestLogger() *slog.Logger {
16+
return common.SetupLogger(&common.LoggingOpts{
17+
Debug: true,
18+
JSON: false,
19+
Service: "test",
20+
Version: "test",
21+
})
22+
}
23+
24+
func Test_Handlers_Healthcheck_Drain_Undrain(t *testing.T) {
25+
testEchoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26+
_, _ = io.WriteString(w, "checkcheck")
27+
}))
28+
defer testEchoServer.Close()
29+
30+
proxy := NewProxy(getTestLogger(), testEchoServer.URL, []atls.Validator{})
31+
32+
{ // Check green path
33+
req := httptest.NewRequest(http.MethodGet, "http://proxyhost.should.not.matter/", nil) //nolint:goconst,nolintlint
34+
w := httptest.NewRecorder()
35+
proxy.ServeHTTP(w, req)
36+
resp := w.Result()
37+
defer resp.Body.Close()
38+
respBody, err := io.ReadAll(resp.Body)
39+
require.NoError(t, err)
40+
require.Equal(t, http.StatusOK, resp.StatusCode, "Must return `Ok`")
41+
require.Equal(t, []byte("checkcheck"), respBody)
42+
}
43+
44+
// Check failure if measurement header is present
45+
for _, header := range []string{AttestationTypeHeader, MeasurementHeader} {
46+
req := httptest.NewRequest(http.MethodGet, "http://proxyhost.should.not.matter/", nil) //nolint:goconst,nolintlint
47+
req.Header.Add(header, "xx")
48+
w := httptest.NewRecorder()
49+
proxy.ServeHTTP(w, req)
50+
resp := w.Result()
51+
defer resp.Body.Close()
52+
respBody, err := io.ReadAll(resp.Body)
53+
require.NoError(t, err)
54+
require.Equal(t, http.StatusForbidden, resp.StatusCode, "Must return `Forbidden` on measurements header")
55+
require.Contains(t, string(respBody), "unexpected")
56+
}
57+
}

0 commit comments

Comments
 (0)