Skip to content

Commit

Permalink
Adds missing early returns (#35)
Browse files Browse the repository at this point in the history
Adds missing early returns on unexpected measurements headers along a regression test.
  • Loading branch information
Ruteri authored Feb 20, 2025
1 parent ca38dda commit 7947875
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
2 changes: 2 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Note: the reverse proxy adds X-Forwarded-For header!
if r.Header.Get(MeasurementHeader) != "" {
http.Error(w, "unexpected measurement header passed", http.StatusForbidden)
return
}
if r.Header.Get(AttestationTypeHeader) != "" {
http.Error(w, "unexpected attestation type header passed", http.StatusForbidden)
return
}

if r.TLS != nil {
Expand Down
57 changes: 57 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package proxy

import (
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"

"github.com/flashbots/cvm-reverse-proxy/common"
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
"github.com/stretchr/testify/require"
)

func getTestLogger() *slog.Logger {
return common.SetupLogger(&common.LoggingOpts{
Debug: true,
JSON: false,
Service: "test",
Version: "test",
})
}

func Test_Handlers_Healthcheck_Drain_Undrain(t *testing.T) {
testEchoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "checkcheck")
}))
defer testEchoServer.Close()

proxy := NewProxy(getTestLogger(), testEchoServer.URL, []atls.Validator{})

{ // Check green path
req := httptest.NewRequest(http.MethodGet, "http://proxyhost.should.not.matter/", nil) //nolint:goconst,nolintlint
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode, "Must return `Ok`")
require.Equal(t, []byte("checkcheck"), respBody)
}

// Check failure if measurement header is present
for _, header := range []string{AttestationTypeHeader, MeasurementHeader} {
req := httptest.NewRequest(http.MethodGet, "http://proxyhost.should.not.matter/", nil) //nolint:goconst,nolintlint
req.Header.Add(header, "xx")
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode, "Must return `Forbidden` on measurements header")
require.Contains(t, string(respBody), "unexpected")
}
}

0 comments on commit 7947875

Please sign in to comment.