diff --git a/go.sum b/go.sum index 53a721d..9709fb7 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvK github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/graphqlws/http.go b/graphqlws/http.go index 73d9a2b..4d1e58c 100644 --- a/graphqlws/http.go +++ b/graphqlws/http.go @@ -2,7 +2,6 @@ package graphqlws import ( "context" - "fmt" "net/http" "time" @@ -11,13 +10,18 @@ import ( "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection" ) -// ProtocolGraphQLWS is websocket subprotocol ID for GraphQL over WebSocket -// see https://github.com/apollographql/subscriptions-transport-ws -const ProtocolGraphQLWS = "graphql-ws" +const ( + // ProtocolGraphQLTransportWS is the modern websocket subprotocol ID for GraphQL over WebSocket. + // see https://github.com/enisdenjo/graphql-ws + ProtocolGraphQLTransportWS = "graphql-transport-ws" + // ProtocolGraphQLWS is the deprecated websocket subprotocol ID for GraphQL over WebSocket. + // see https://github.com/apollographql/subscriptions-transport-ws + ProtocolGraphQLWS = "graphql-ws" +) var defaultUpgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, - Subprotocols: []string{ProtocolGraphQLWS}, + Subprotocols: []string{ProtocolGraphQLTransportWS, ProtocolGraphQLWS}, } type handler struct { @@ -55,7 +59,21 @@ type Option interface { type options struct { contextGenerators []ContextGenerator - connectOptions []connection.Option + readLimit int64 + writeTimeout time.Duration + hasReadLimit bool + hasWriteTimeout bool +} + +func (o *options) connectionOptions() []connection.Option { + var connOptions []connection.Option + if o.hasReadLimit { + connOptions = append(connOptions, connection.ReadLimit(o.readLimit)) + } + if o.hasWriteTimeout { + connOptions = append(connOptions, connection.WriteTimeout(o.writeTimeout)) + } + return connOptions } type optionFunc func(*options) @@ -75,16 +93,16 @@ func WithContextGenerator(f ContextGenerator) Option { // WithReadLimit limits the maximum size of incoming messages func WithReadLimit(limit int64) Option { return optionFunc(func(o *options) { - connOpt := connection.ReadLimit(limit) - o.connectOptions = append(o.connectOptions, connOpt) + o.readLimit = limit + o.hasReadLimit = true }) } // WithWriteTimeout sets a timeout for outgoing messages func WithWriteTimeout(d time.Duration) Option { return optionFunc(func(o *options) { - connOpt := connection.WriteTimeout(d) - o.connectOptions = append(o.connectOptions, connOpt) + o.writeTimeout = d + o.hasWriteTimeout = true }) } @@ -109,37 +127,42 @@ func (h *handler) NewHandlerFunc(svc connection.GraphQLService, httpHandler http o := applyOptions(options...) return func(w http.ResponseWriter, r *http.Request) { - for _, subprotocol := range websocket.Subprotocols(r) { - if subprotocol != ProtocolGraphQLWS { - continue - } - - ctx, err := buildContext(r, o.contextGenerators) - if err != nil { - w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) + if !websocket.IsWebSocketUpgrade(r) { + if httpHandler == nil { + http.Error(w, "Not Found", http.StatusNotFound) return } - ws, err := h.Upgrader.Upgrade(w, r, nil) - if err != nil { - w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) - return - } + httpHandler.ServeHTTP(w, r) + return + } - if ws.Subprotocol() != ProtocolGraphQLWS { - w.Header().Set("X-WebSocket-Upgrade-Failure", - fmt.Sprintf("upgraded websocket has wrong subprotocol (%s)", ws.Subprotocol())) - ws.Close() - return - } + ctx, err := buildContext(r, o.contextGenerators) + if err != nil { + w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) + return + } - go connection.Connect(ctx, ws, svc, o.connectOptions...) + ws, err := h.Upgrader.Upgrade(w, r, nil) + if err != nil { + // UPGRADE FAILED: The Upgrader has already written an error response. + // Do not call the httpHandler return } - // Fallback to HTTP - w.Header().Set("X-WebSocket-Upgrade-Failure", "no subprotocols available") - httpHandler.ServeHTTP(w, r) + switch ws.Subprotocol() { + case ProtocolGraphQLWS: + go connection.Connect(ctx, ws, svc, o.connectionOptions()...) + + case ProtocolGraphQLTransportWS: + // TODO: support graphql-transport-ws subprotocol + w.Header().Set("X-WebSocket-Upgrade-Failure", "unsupported subprotocol") + ws.Close() + + default: + w.Header().Set("X-WebSocket-Upgrade-Failure", "unsupported subprotocol") + ws.Close() + } } } diff --git a/graphqlws/http_test.go b/graphqlws/http_test.go index 97d7654..58a38ec 100644 --- a/graphqlws/http_test.go +++ b/graphqlws/http_test.go @@ -4,12 +4,217 @@ import ( "context" "errors" "net/http" + "net/http/httptest" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type mockGraphQLService struct { + mock.Mock +} + +func (s *mockGraphQLService) Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (<-chan interface{}, error) { + args := s.Called(ctx, document, operationName, variableValues) + return args.Get(0).(<-chan interface{}), args.Error(1) +} + +type mockHTTPHandler struct { + mock.Mock +} + +func (h *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.Called(w, r) + w.WriteHeader(http.StatusOK) +} + +type testMocker struct { + handler http.Handler + mockSvc *mockGraphQLService + mockHTTP *mockHTTPHandler + done chan bool +} + +func TestNewHandlerFunc(t *testing.T) { + type Args struct { + isWebSocketTest bool + subprotocols []string + } + + type Want struct { + assertion func(t *testing.T, conn *websocket.Conn) + checkError func(t *testing.T, err error) + expectedSubprotocol string + } + + testTable := map[string]struct { + args Args + mockExpectations func(m testMocker) + setup func() testMocker + want Want + }{ + "legacy protocol ok": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{ProtocolGraphQLWS}, + }, + mockExpectations: func(m testMocker) { + c := make(chan interface{}) + close(c) + + m.mockSvc.On("Subscribe", mock.AnythingOfType("*context.valueCtx"), "subscription{}", "", (map[string]interface{})(nil)).Return((<-chan interface{})(c), nil).Run(func(args mock.Arguments) { + m.done <- true + }).Once() + }, + setup: func() testMocker { + mockSvc := new(mockGraphQLService) + return testMocker{ + handler: NewHandlerFunc(mockSvc, nil), + mockSvc: mockSvc, + done: make(chan bool, 1), + } + }, + want: Want{ + expectedSubprotocol: ProtocolGraphQLWS, + assertion: func(t *testing.T, conn *websocket.Conn) { + err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_init"}`)) + require.NoError(t, err) + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"id":"1","type":"start","payload":{"query":"subscription{}"}}`)) + require.NoError(t, err) + + _, _, err = conn.ReadMessage() + require.NoError(t, err) + }, + }, + }, + "graphql-transport-ws protocol unsupported ": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{ProtocolGraphQLTransportWS}, + }, + setup: func() testMocker { + return testMocker{handler: NewHandlerFunc(nil, nil)} + }, + want: Want{ + expectedSubprotocol: ProtocolGraphQLTransportWS, + assertion: func(t *testing.T, conn *websocket.Conn) { + var closeError *websocket.CloseError + + _, _, err := conn.ReadMessage() + assert.ErrorAs(t, err, &closeError, "Expected server to close connection for placeholder") + }, + }, + }, + "unsupported protocol error": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{"unsupported-protocol"}, + }, + setup: func() testMocker { + return testMocker{handler: NewHandlerFunc(nil, nil)} + }, + want: Want{ + assertion: func(t *testing.T, conn *websocket.Conn) { + assert.Equal(t, "", conn.Subprotocol(), "Expected no subprotocol to be selected") + + var closeError *websocket.CloseError + + _, _, err := conn.ReadMessage() + assert.ErrorAs(t, err, &closeError, "Expected server to close connection for unsupported protocol") + }, + }, + }, + "HTTP fallback ok": { + mockExpectations: func(m testMocker) { + m.mockHTTP.On("ServeHTTP", mock.Anything, mock.MatchedBy(func(r *http.Request) bool { return r.Method == http.MethodGet })).Run(func(args mock.Arguments) { + m.done <- true + }).Once() + }, + setup: func() testMocker { + mockHTTP := new(mockHTTPHandler) + return testMocker{ + handler: NewHandlerFunc(nil, mockHTTP), + mockHTTP: mockHTTP, + done: make(chan bool, 1), + } + }, + want: Want{ + assertion: func(t *testing.T, conn *websocket.Conn) {}, + }, + }, + } + + for name, tt := range testTable { + tt := tt + + t.Run(name, func(t *testing.T) { + t.Parallel() + + mocker := tt.setup() + if tt.mockExpectations != nil { + tt.mockExpectations(mocker) + } + + server := httptest.NewServer(mocker.handler) + defer server.Close() + + if !tt.args.isWebSocketTest { + _, err := http.Get(server.URL) + require.NoError(t, err) + + select { + case <-mocker.done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for ServeHTTP mock to be called") + } + + mocker.mockHTTP.AssertExpectations(t) + return + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := websocket.Dialer{Subprotocols: tt.args.subprotocols} + conn, _, err := dialer.Dial(wsURL, nil) + + if tt.want.checkError != nil { + tt.want.checkError(t, err) + return + } + + require.NoError(t, err) + defer conn.Close() + + assert.Equal(t, tt.want.expectedSubprotocol, conn.Subprotocol()) + + if tt.want.assertion != nil { + tt.want.assertion(t, conn) + } + + if mocker.done != nil { + select { + case <-mocker.done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for mock expectation to be met") + } + } + + if mocker.mockSvc != nil { + mocker.mockSvc.AssertExpectations(t) + } + + if mocker.mockHTTP != nil { + mocker.mockHTTP.AssertExpectations(t) + } + }) + } +} + func TestContextGenerators(t *testing.T) { contextBuilderFunc := func() ContextGeneratorFunc { return func(ctx context.Context, r *http.Request) (context.Context, error) { @@ -28,7 +233,7 @@ func TestContextGenerators(t *testing.T) { } type want struct { Context context.Context - Error string + Error string } testTable := map[string]struct { @@ -65,6 +270,5 @@ func TestContextGenerators(t *testing.T) { assert.Equal(t, tt.Want.Context, ctx, "New context generated") require.NoError(t, err, "Error generating context") }) - } -} \ No newline at end of file +}