diff --git a/go.sum b/go.sum index 53a721d..9709fb7 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvK github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/graphqlws/http.go b/graphqlws/http.go index 73d9a2b..ccb8c14 100644 --- a/graphqlws/http.go +++ b/graphqlws/http.go @@ -2,22 +2,27 @@ package graphqlws import ( "context" - "fmt" "net/http" "time" "github.com/gorilla/websocket" "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection" + "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/gql" ) -// ProtocolGraphQLWS is websocket subprotocol ID for GraphQL over WebSocket -// see https://github.com/apollographql/subscriptions-transport-ws -const ProtocolGraphQLWS = "graphql-ws" +const ( + // ProtocolGraphQLTransportWS is the modern websocket subprotocol ID for GraphQL over WebSocket. + // see https://github.com/enisdenjo/graphql-ws + ProtocolGraphQLTransportWS = "graphql-transport-ws" + // ProtocolGraphQLWS is the deprecated websocket subprotocol ID for GraphQL over WebSocket. + // see https://github.com/apollographql/subscriptions-transport-ws + ProtocolGraphQLWS = "graphql-ws" +) var defaultUpgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, - Subprotocols: []string{ProtocolGraphQLWS}, + Subprotocols: []string{ProtocolGraphQLTransportWS, ProtocolGraphQLWS}, } type handler struct { @@ -55,7 +60,21 @@ type Option interface { type options struct { contextGenerators []ContextGenerator - connectOptions []connection.Option + readLimit int64 + writeTimeout time.Duration + hasReadLimit bool + hasWriteTimeout bool +} + +func (o *options) connectionOptions() []connection.Option { + var connOptions []connection.Option + if o.hasReadLimit { + connOptions = append(connOptions, connection.ReadLimit(o.readLimit)) + } + if o.hasWriteTimeout { + connOptions = append(connOptions, connection.WriteTimeout(o.writeTimeout)) + } + return connOptions } type optionFunc func(*options) @@ -75,16 +94,16 @@ func WithContextGenerator(f ContextGenerator) Option { // WithReadLimit limits the maximum size of incoming messages func WithReadLimit(limit int64) Option { return optionFunc(func(o *options) { - connOpt := connection.ReadLimit(limit) - o.connectOptions = append(o.connectOptions, connOpt) + o.readLimit = limit + o.hasReadLimit = true }) } // WithWriteTimeout sets a timeout for outgoing messages func WithWriteTimeout(d time.Duration) Option { return optionFunc(func(o *options) { - connOpt := connection.WriteTimeout(d) - o.connectOptions = append(o.connectOptions, connOpt) + o.writeTimeout = d + o.hasWriteTimeout = true }) } @@ -99,47 +118,52 @@ func applyOptions(opts ...Option) *options { } // NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets -func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc { +func NewHandlerFunc(svc gql.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc { handler := NewHandler() return handler.NewHandlerFunc(svc, httpHandler, options...) } // NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets -func (h *handler) NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc { +func (h *handler) NewHandlerFunc(svc gql.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc { o := applyOptions(options...) return func(w http.ResponseWriter, r *http.Request) { - for _, subprotocol := range websocket.Subprotocols(r) { - if subprotocol != ProtocolGraphQLWS { - continue - } - - ctx, err := buildContext(r, o.contextGenerators) - if err != nil { - w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) + if !websocket.IsWebSocketUpgrade(r) { + if httpHandler == nil { + http.Error(w, "Not Found", http.StatusNotFound) return } - ws, err := h.Upgrader.Upgrade(w, r, nil) - if err != nil { - w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) - return - } + httpHandler.ServeHTTP(w, r) + return + } - if ws.Subprotocol() != ProtocolGraphQLWS { - w.Header().Set("X-WebSocket-Upgrade-Failure", - fmt.Sprintf("upgraded websocket has wrong subprotocol (%s)", ws.Subprotocol())) - ws.Close() - return - } + ctx, err := buildContext(r, o.contextGenerators) + if err != nil { + w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error()) + return + } - go connection.Connect(ctx, ws, svc, o.connectOptions...) + ws, err := h.Upgrader.Upgrade(w, r, nil) + if err != nil { + // UPGRADE FAILED: The Upgrader has already written an error response. + // Do not call the httpHandler return } - // Fallback to HTTP - w.Header().Set("X-WebSocket-Upgrade-Failure", "no subprotocols available") - httpHandler.ServeHTTP(w, r) + switch ws.Subprotocol() { + case ProtocolGraphQLWS: + go connection.Connect(ctx, ws, svc, o.connectionOptions()...) + + case ProtocolGraphQLTransportWS: + // TODO: support graphql-transport-ws subprotocol + w.Header().Set("X-WebSocket-Upgrade-Failure", "unsupported subprotocol") + ws.Close() + + default: + w.Header().Set("X-WebSocket-Upgrade-Failure", "unsupported subprotocol") + ws.Close() + } } } diff --git a/graphqlws/http_test.go b/graphqlws/http_test.go index 97d7654..58a38ec 100644 --- a/graphqlws/http_test.go +++ b/graphqlws/http_test.go @@ -4,12 +4,217 @@ import ( "context" "errors" "net/http" + "net/http/httptest" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type mockGraphQLService struct { + mock.Mock +} + +func (s *mockGraphQLService) Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (<-chan interface{}, error) { + args := s.Called(ctx, document, operationName, variableValues) + return args.Get(0).(<-chan interface{}), args.Error(1) +} + +type mockHTTPHandler struct { + mock.Mock +} + +func (h *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.Called(w, r) + w.WriteHeader(http.StatusOK) +} + +type testMocker struct { + handler http.Handler + mockSvc *mockGraphQLService + mockHTTP *mockHTTPHandler + done chan bool +} + +func TestNewHandlerFunc(t *testing.T) { + type Args struct { + isWebSocketTest bool + subprotocols []string + } + + type Want struct { + assertion func(t *testing.T, conn *websocket.Conn) + checkError func(t *testing.T, err error) + expectedSubprotocol string + } + + testTable := map[string]struct { + args Args + mockExpectations func(m testMocker) + setup func() testMocker + want Want + }{ + "legacy protocol ok": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{ProtocolGraphQLWS}, + }, + mockExpectations: func(m testMocker) { + c := make(chan interface{}) + close(c) + + m.mockSvc.On("Subscribe", mock.AnythingOfType("*context.valueCtx"), "subscription{}", "", (map[string]interface{})(nil)).Return((<-chan interface{})(c), nil).Run(func(args mock.Arguments) { + m.done <- true + }).Once() + }, + setup: func() testMocker { + mockSvc := new(mockGraphQLService) + return testMocker{ + handler: NewHandlerFunc(mockSvc, nil), + mockSvc: mockSvc, + done: make(chan bool, 1), + } + }, + want: Want{ + expectedSubprotocol: ProtocolGraphQLWS, + assertion: func(t *testing.T, conn *websocket.Conn) { + err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_init"}`)) + require.NoError(t, err) + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"id":"1","type":"start","payload":{"query":"subscription{}"}}`)) + require.NoError(t, err) + + _, _, err = conn.ReadMessage() + require.NoError(t, err) + }, + }, + }, + "graphql-transport-ws protocol unsupported ": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{ProtocolGraphQLTransportWS}, + }, + setup: func() testMocker { + return testMocker{handler: NewHandlerFunc(nil, nil)} + }, + want: Want{ + expectedSubprotocol: ProtocolGraphQLTransportWS, + assertion: func(t *testing.T, conn *websocket.Conn) { + var closeError *websocket.CloseError + + _, _, err := conn.ReadMessage() + assert.ErrorAs(t, err, &closeError, "Expected server to close connection for placeholder") + }, + }, + }, + "unsupported protocol error": { + args: Args{ + isWebSocketTest: true, + subprotocols: []string{"unsupported-protocol"}, + }, + setup: func() testMocker { + return testMocker{handler: NewHandlerFunc(nil, nil)} + }, + want: Want{ + assertion: func(t *testing.T, conn *websocket.Conn) { + assert.Equal(t, "", conn.Subprotocol(), "Expected no subprotocol to be selected") + + var closeError *websocket.CloseError + + _, _, err := conn.ReadMessage() + assert.ErrorAs(t, err, &closeError, "Expected server to close connection for unsupported protocol") + }, + }, + }, + "HTTP fallback ok": { + mockExpectations: func(m testMocker) { + m.mockHTTP.On("ServeHTTP", mock.Anything, mock.MatchedBy(func(r *http.Request) bool { return r.Method == http.MethodGet })).Run(func(args mock.Arguments) { + m.done <- true + }).Once() + }, + setup: func() testMocker { + mockHTTP := new(mockHTTPHandler) + return testMocker{ + handler: NewHandlerFunc(nil, mockHTTP), + mockHTTP: mockHTTP, + done: make(chan bool, 1), + } + }, + want: Want{ + assertion: func(t *testing.T, conn *websocket.Conn) {}, + }, + }, + } + + for name, tt := range testTable { + tt := tt + + t.Run(name, func(t *testing.T) { + t.Parallel() + + mocker := tt.setup() + if tt.mockExpectations != nil { + tt.mockExpectations(mocker) + } + + server := httptest.NewServer(mocker.handler) + defer server.Close() + + if !tt.args.isWebSocketTest { + _, err := http.Get(server.URL) + require.NoError(t, err) + + select { + case <-mocker.done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for ServeHTTP mock to be called") + } + + mocker.mockHTTP.AssertExpectations(t) + return + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := websocket.Dialer{Subprotocols: tt.args.subprotocols} + conn, _, err := dialer.Dial(wsURL, nil) + + if tt.want.checkError != nil { + tt.want.checkError(t, err) + return + } + + require.NoError(t, err) + defer conn.Close() + + assert.Equal(t, tt.want.expectedSubprotocol, conn.Subprotocol()) + + if tt.want.assertion != nil { + tt.want.assertion(t, conn) + } + + if mocker.done != nil { + select { + case <-mocker.done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for mock expectation to be met") + } + } + + if mocker.mockSvc != nil { + mocker.mockSvc.AssertExpectations(t) + } + + if mocker.mockHTTP != nil { + mocker.mockHTTP.AssertExpectations(t) + } + }) + } +} + func TestContextGenerators(t *testing.T) { contextBuilderFunc := func() ContextGeneratorFunc { return func(ctx context.Context, r *http.Request) (context.Context, error) { @@ -28,7 +233,7 @@ func TestContextGenerators(t *testing.T) { } type want struct { Context context.Context - Error string + Error string } testTable := map[string]struct { @@ -65,6 +270,5 @@ func TestContextGenerators(t *testing.T) { assert.Equal(t, tt.Want.Context, ctx, "New context generated") require.NoError(t, err, "Error generating context") }) - } -} \ No newline at end of file +} diff --git a/graphqlws/internal/connection/connection.go b/graphqlws/internal/connection/connection.go index fb253fe..c97eca4 100644 --- a/graphqlws/internal/connection/connection.go +++ b/graphqlws/internal/connection/connection.go @@ -7,6 +7,8 @@ import ( "fmt" "sync" "time" + + "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/gql" ) type operationMessageType string @@ -50,14 +52,9 @@ type startMessagePayload struct { type initMessagePayload struct{} -// GraphQLService interface -type GraphQLService interface { - Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) -} - type connection struct { cancel func() - service GraphQLService + service gql.GraphQLService writeTimeout time.Duration ws wsConnection } @@ -111,7 +108,7 @@ func WriteTimeout(d time.Duration) Option { // Connect implements the apollographql subscriptions-transport-ws protocol@v0.9.4 // https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md -func Connect(ctx context.Context, ws wsConnection, service GraphQLService, options ...Option) func() { +func Connect(ctx context.Context, ws wsConnection, service gql.GraphQLService, options ...Option) func() { conn := &connection{ service: service, ws: ws, diff --git a/graphqlws/internal/gql/gql.go b/graphqlws/internal/gql/gql.go new file mode 100644 index 0000000..dd7fbde --- /dev/null +++ b/graphqlws/internal/gql/gql.go @@ -0,0 +1,7 @@ +package gql + +import "context" + +type GraphQLService interface { + Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) +} diff --git a/graphqlws/internal/transport/transport.go b/graphqlws/internal/transport/transport.go new file mode 100644 index 0000000..3845614 --- /dev/null +++ b/graphqlws/internal/transport/transport.go @@ -0,0 +1,346 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/gql" +) + +// operationMap holds active subscriptions. +type operationMap struct { + ops map[string]func() + mtx *sync.RWMutex +} + +func newOperationMap() operationMap { + return operationMap{ + ops: make(map[string]func()), + mtx: &sync.RWMutex{}, + } +} + +func (o *operationMap) add(name string, done func()) { + o.mtx.Lock() + o.ops[name] = done + o.mtx.Unlock() +} + +func (o *operationMap) get(name string) (func(), bool) { + o.mtx.RLock() + f, ok := o.ops[name] + o.mtx.RUnlock() + return f, ok +} + +func (o *operationMap) delete(name string) { + o.mtx.Lock() + delete(o.ops, name) + o.mtx.Unlock() +} + +type operationMessageType string + +// https://github.com/graphql/graphql-over-http/blob/main/rfcs/GraphQLOverWebSocket.md +const ( + typeConnectionInit operationMessageType = "connection_init" + typeConnectionAck operationMessageType = "connection_ack" + typePing operationMessageType = "ping" + typePong operationMessageType = "pong" + typeSubscribe operationMessageType = "subscribe" + typeNext operationMessageType = "next" + typeError operationMessageType = "error" + typeComplete operationMessageType = "complete" +) + +type operationMessage struct { + ID string `json:"id,omitempty"` + Payload json.RawMessage `json:"payload,omitempty"` + Type operationMessageType `json:"type"` +} + +type subscribeMessagePayload struct { + OperationName string `json:"operationName"` + Query string `json:"query"` + Variables map[string]interface{} `json:"variables"` +} + +type wsConnection interface { + Close() error + ReadJSON(v interface{}) error + SetReadLimit(limit int64) + SetWriteDeadline(t time.Time) error + WriteJSON(v interface{}) error +} + +type connection struct { + cancel func() + service gql.GraphQLService + writeTimeout time.Duration + ws wsConnection +} + +type sendFunc func(id string, omType operationMessageType, payload json.RawMessage) + +type Option func(conn *connection) + +func ReadLimit(limit int64) Option { + return func(conn *connection) { + conn.ws.SetReadLimit(limit) + } +} + +func WriteTimeout(d time.Duration) Option { + return func(conn *connection) { + conn.writeTimeout = d + } +} + +func Connect(ctx context.Context, ws wsConnection, service gql.GraphQLService, options ...Option) { + conn := &connection{ + service: service, + ws: ws, + } + + defaultOpts := []Option{ + ReadLimit(4096), + WriteTimeout(time.Second * 3), + } + + for _, opt := range append(defaultOpts, options...) { + opt(conn) + } + + ctx, cancel := context.WithCancel(ctx) + conn.cancel = cancel + conn.readLoop(ctx, conn.writeLoop(ctx)) +} + +func (conn *connection) writeLoop(ctx context.Context) sendFunc { + stop := make(chan struct{}) + out := make(chan *operationMessage, 1) // Using a small buffer can sometimes help, but is not essential for the fix. + + send := func(id string, omType operationMessageType, payload json.RawMessage) { + select { + case <-stop: + return + case out <- &operationMessage{ID: id, Type: omType, Payload: payload}: + } + } + + go func() { + defer close(stop) + defer conn.ws.Close() + + for { + select { + case msg := <-out: + if err := conn.ws.SetWriteDeadline(time.Now().Add(conn.writeTimeout)); err != nil { + return + } + if err := conn.ws.WriteJSON(msg); err != nil { + return + } + case <-ctx.Done(): + // Context is canceled. Drain any remaining messages in the out channel. + for { + select { + case msg := <-out: + // Still attempt to write pending messages + conn.ws.SetWriteDeadline(time.Now().Add(conn.writeTimeout)) + if err := conn.ws.WriteJSON(msg); err != nil { + // On error, we can't do much more, so exit. + return + } + default: + // The out channel is empty, we can now safely exit the goroutine. + return + } + } + } + } + }() + + return send +} + +func (conn *connection) close() { + conn.cancel() +} + +func (conn *connection) drainChannel(out chan *operationMessage) { + for { + select { + case msg := <-out: + conn.ws.SetWriteDeadline(time.Now().Add(conn.writeTimeout)) + if err := conn.ws.WriteJSON(msg); err != nil { + return + } + default: + return + } + } +} + +func (conn *connection) readLoop(ctx context.Context, send sendFunc) { + defer conn.close() + + ops := newOperationMap() + initDone := false + msgChan := make(chan *operationMessage) + errChan := make(chan error, 1) + + go func() { + for { + var msg operationMessage + if err := conn.ws.ReadJSON(&msg); err != nil { + errChan <- err + return + } + msgChan <- &msg + } + }() + + initTimer := time.NewTimer(conn.writeTimeout) + defer initTimer.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-errChan: + // Read error occurred (e.g., client closed connection) + // Find and cancel all active operations + ops.mtx.Lock() + for id, cancel := range ops.ops { + cancel() + delete(ops.ops, id) + } + ops.mtx.Unlock() + return + case <-initTimer.C: + if !initDone { + // Client failed to send connection_init in time + send("", typeError, errPayload(errors.New("connection initialisation timeout"))) + return + } + case msg := <-msgChan: + if !initDone { + initTimer.Stop() + + if msg.Type != typeConnectionInit { + send("", typeError, errPayload(errors.New("connection_init message not received"))) + return + } + + // TODO: Add payload handling for auth here if needed + send("", typeConnectionAck, nil) + initDone = true + + continue + } + + err := conn.processMessages(ctx, msg, send, ops) + if err != nil { + return + } + } + } +} + +// processMessages handles the different types messages in the graphql-transport-ws subprotocol +func (conn *connection) processMessages(ctx context.Context, msg *operationMessage, send sendFunc, ops operationMap) error { + switch msg.Type { + case typeConnectionInit: + send("", typeError, errPayload(errors.New("connection_init sent twice"))) + return errors.New("error connection_init sent twice, the connection will be closed") + + case typePing: + send("", typePong, msg.Payload) + + case typeSubscribe: + if msg.ID == "" { + send("", typeError, errPayload(errors.New("missing ID for subscribe operation"))) + return nil + } + + if _, exists := ops.get(msg.ID); exists { + send(msg.ID, typeError, errPayload(errors.New("duplicate operation ID"))) + return nil + } + + var payload subscribeMessagePayload + + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + send(msg.ID, typeError, errPayload(fmt.Errorf("invalid subscribe payload: %w", err))) + return nil + } + + opCtx, opCancel := context.WithCancel(ctx) + ops.add(msg.ID, opCancel) + + go conn.runSubscription(opCtx, msg.ID, payload, send, ops) + + case typeComplete: + if msg.ID == "" { + send("", typeError, errPayload(errors.New("missing ID for complete operation"))) + return nil + } + + if opCancel, ok := ops.get(msg.ID); ok { + opCancel() + ops.delete(msg.ID) + } + + default: + send(msg.ID, typeError, errPayload(fmt.Errorf("unknown message type: %s", msg.Type))) + } + + return nil +} + +func (conn *connection) runSubscription(ctx context.Context, id string, payload subscribeMessagePayload, send sendFunc, ops operationMap) { + defer ops.delete(id) + + c, err := conn.service.Subscribe(ctx, payload.Query, payload.OperationName, payload.Variables) + if err != nil { + send(id, typeError, errPayload(err)) + return + } + + for { + select { + case <-ctx.Done(): + return + case data, more := <-c: + if !more { + // Subscription stream closed + send(id, typeComplete, nil) + return + } + + // Stream has data, send a 'next' message + jsonPayload, err := json.Marshal(data) + if err != nil { + send(id, typeError, errPayload(fmt.Errorf("failed to marshal payload: %w", err))) + continue + } + + send(id, typeNext, jsonPayload) + } + } +} + +func errPayload(err error) json.RawMessage { + b, _ := json.Marshal(struct { + Message string `json:"message"` + }{ + Message: err.Error(), + }) + + return b +} diff --git a/graphqlws/internal/transport/transport_test.go b/graphqlws/internal/transport/transport_test.go new file mode 100644 index 0000000..595deb0 --- /dev/null +++ b/graphqlws/internal/transport/transport_test.go @@ -0,0 +1,432 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockGraphQLService struct { + mock.Mock +} + +// Subscribe is a mock implementation of a GraphQL subscription operation. +// It records the call with the provided context and GraphQL parameters, +// then returns a pre-configured channel and error based on the test setup. +func (s *mockGraphQLService) Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (<-chan interface{}, error) { + args := s.Called(ctx, document, operationName, variableValues) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(<-chan interface{}), args.Error(1) +} + +type mockConnection struct { + in chan json.RawMessage + out chan json.RawMessage + closeCalled chan bool + readLimit int64 + writeTimeout time.Duration + mtx sync.Mutex + isClosed bool +} + +func newMockConnection() *mockConnection { + return &mockConnection{ + in: make(chan json.RawMessage, 10), + out: make(chan json.RawMessage, 10), + closeCalled: make(chan bool, 1), + } +} + +// ReadJSON reads the next JSON-formatted message from the connection and unmarshals it +func (ws *mockConnection) ReadJSON(v interface{}) error { + msg, ok := <-ws.in + if !ok { + return errors.New("connection closed") + } + + return json.Unmarshal(msg, v) +} + +// WriteJSON marshals the value v to JSON and sends it as a message over the connection +func (ws *mockConnection) WriteJSON(v interface{}) error { + ws.mtx.Lock() + defer ws.mtx.Unlock() + + if ws.isClosed { + return errors.New("writing to closed connection") + } + + data, err := json.Marshal(v) + if err != nil { + return err + } + + ws.out <- data + return nil +} + +func (ws *mockConnection) SetReadLimit(limit int64) { + ws.readLimit = limit +} + +func (ws *mockConnection) SetWriteDeadline(t time.Time) error { + ws.writeTimeout = time.Until(t) + return nil +} + +func (ws *mockConnection) Close() error { + ws.mtx.Lock() + defer ws.mtx.Unlock() + + if !ws.isClosed { + ws.isClosed = true + + close(ws.closeCalled) + + close(ws.out) + } + + return nil +} + +type mocker struct { + conn *mockConnection + mockSvc *mockGraphQLService +} + +func setupTest(t *testing.T) mocker { + t.Helper() + + return mocker{ + conn: newMockConnection(), + mockSvc: new(mockGraphQLService), + } +} + +func TestOperationMap(t *testing.T) { + t.Parallel() + fn1 := func() {} + + type args struct { + key string + fn func() + startMap map[string]func() + } + type want struct { + fn func() + ok bool + } + + t.Run("add_and_get", func(t *testing.T) { + t.Parallel() + testTable := map[string]struct { + args args + want want + }{ + "add_and_get_existing_key": { + args: args{key: "key1", fn: fn1, startMap: map[string]func(){}}, + want: want{fn: fn1, ok: true}, + }, + "get_non_existent_key": { + args: args{key: "key2", startMap: map[string]func(){}}, + want: want{fn: nil, ok: false}, + }, + } + + for name, tt := range testTable { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + + om := newOperationMap() + om.ops = tt.args.startMap + + if tt.args.fn != nil { + om.add(tt.args.key, tt.args.fn) + } + + gotFn, gotOk := om.get(tt.args.key) + + assert.Equal(t, tt.want.ok, gotOk, "Expected 'ok' to match") + + if tt.want.fn != nil { + assert.NotNil(t, gotFn, "Expected function to be non-nil") + } else { + assert.Nil(t, gotFn, "Expected function to be nil") + } + }) + } + }) + + t.Run("delete", func(t *testing.T) { + t.Parallel() + testTable := map[string]struct { + args args + want map[string]func() + }{ + "delete_existing_key": { + args: args{key: "key1", startMap: map[string]func(){"key1": fn1, "key2": fn1}}, + want: map[string]func(){"key2": fn1}, + }, + "delete_non_existent_key": { + args: args{key: "key3", startMap: map[string]func(){"key1": fn1, "key2": fn1}}, + want: map[string]func(){"key1": fn1, "key2": fn1}, + }, + "delete_from_empty_map": { + args: args{key: "key1", startMap: map[string]func(){}}, + want: map[string]func(){}, + }, + } + + for name, tt := range testTable { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + om := newOperationMap() + om.ops = tt.args.startMap + + om.delete(tt.args.key) + + gotKeys := getMapKeys(om.ops) + wantKeys := getMapKeys(tt.want) + assert.ElementsMatch(t, wantKeys, gotKeys, "Map keys do not match after delete") + }) + } + }) +} + +func TestConnect(t *testing.T) { + t.Parallel() + + type Args struct { + clientMessages []string + options []Option + } + type Want struct { + serverMessages []string + assertClose bool + } + + testTable := map[string]struct { + setup func(t *testing.T) mocker + mockExpectations func(h mocker) + args Args + want Want + }{ + "Successful subscription": { + setup: setupTest, + mockExpectations: func(h mocker) { + c := make(chan interface{}, 1) + c <- json.RawMessage(`{"data":{"foo":"bar"}}`) + + close(c) + + h.mockSvc.On("Subscribe", mock.Anything, "sub { hello }", "MySub", mock.Anything).Return((<-chan interface{})(c), nil).Once() + }, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"id":"1","type":"subscribe","payload":{"query":"sub { hello }","operationName":"MySub","variables":{}}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"id":"1","type":"next","payload":{"data":{"foo":"bar"}}}`, `{"id":"1","type":"complete"}`}, + assertClose: false, + }, + }, + "Error if subscribe sent first": { + setup: setupTest, + args: Args{ + clientMessages: []string{`{"id":"1","type":"subscribe","payload":{}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"error","payload":{"message":"connection_init message not received"}}`}, + assertClose: true, + }, + }, + "Error if connection_init sent twice": { + setup: setupTest, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"type":"connection_init"}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"type":"error","payload":{"message":"connection_init sent twice"}}`}, + assertClose: true, + }, + }, + "Ping/Pong": { + setup: setupTest, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"type":"ping","payload":{"key":"val"}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"type":"pong","payload":{"key":"val"}}`}, + assertClose: false, + }, + }, + "Error on subscribe with empty ID": { + setup: setupTest, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"id":"","type":"subscribe","payload":{}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"type":"error","payload":{"message":"missing ID for subscribe operation"}}`}, + assertClose: false, + }, + }, + "Error on subscribe with duplicate ID": { + setup: setupTest, + mockExpectations: func(h mocker) { + c := make(chan interface{}) + h.mockSvc.On("Subscribe", mock.Anything, "sub { hello }", "", mock.Anything).Return((<-chan interface{})(c), nil).Once() + }, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"id":"1","type":"subscribe","payload":{"query":"sub { hello }"}}`, `{"id":"1","type":"subscribe","payload":{"query":"sub { hello }"}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"id":"1","type":"error","payload":{"message":"duplicate operation ID"}}`}, + assertClose: false, + }, + }, + "Complete message stops subscription": { + setup: setupTest, + mockExpectations: func(h mocker) { + c := make(chan interface{}) + h.mockSvc.On("Subscribe", mock.Anything, "sub { hello }", "", mock.Anything).Return((<-chan interface{})(c), nil).Once() + }, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"id":"1","type":"subscribe","payload":{"query":"sub { hello }"}}`, `{"id":"1","type":"complete"}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`}, + assertClose: false, + }, + }, + "Subscription error": { + setup: setupTest, + mockExpectations: func(h mocker) { + h.mockSvc.On("Subscribe", mock.Anything, "sub { hello }", "", mock.Anything).Return((<-chan interface{})(nil), errors.New("test sub error")).Once() + }, + args: Args{ + clientMessages: []string{`{"type":"connection_init"}`, `{"id":"1s","type":"subscribe","payload":{"query":"sub { hello }"}}`}, + }, + want: Want{ + serverMessages: []string{`{"type":"connection_ack"}`, `{"id":"1s","type":"error","payload":{"message":"test sub error"}}`}, + assertClose: false, + }, + }, + "Connection init timeout": { + setup: setupTest, + args: Args{ + options: []Option{WriteTimeout(50 * time.Millisecond)}, + }, + want: Want{ + serverMessages: []string{`{"type":"error","payload":{"message":"connection initialisation timeout"}}`}, + assertClose: true, + }, + }, + } + + for name, tt := range testTable { + tt := tt + + t.Run(name, func(t *testing.T) { + t.Parallel() + + h := tt.setup(t) + + if tt.mockExpectations != nil { + tt.mockExpectations(h) + } + + go Connect(context.Background(), h.conn, h.mockSvc, tt.args.options...) + + go func() { + for _, msg := range tt.args.clientMessages { + h.conn.in <- json.RawMessage(msg) + } + + if !tt.want.assertClose { + time.Sleep(100 * time.Millisecond) + close(h.conn.in) + } + }() + + receivedMessages := receiveTestMessages(t, h) + + assert.Equal(t, len(tt.want.serverMessages), len(receivedMessages), "Unexpected number of messages received") + + for i, expectedMsg := range tt.want.serverMessages { + if i >= len(receivedMessages) { + break + } + requireEqualJSON(t, expectedMsg, receivedMessages[i], "Message %d mismatch", i) + } + + if tt.want.assertClose { + select { + case <-h.conn.closeCalled: + // Server closed connection as expected + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for server to close connection") + } + } + + h.mockSvc.AssertExpectations(t) + }) + } +} + +func getMapKeys(m map[string]func()) []string { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + + sort.Strings(keys) + return keys +} + +func requireEqualJSON(t *testing.T, expected string, actual json.RawMessage, msgAndArgs ...interface{}) { + t.Helper() + + var expJSON, actJSON interface{} + + err := json.Unmarshal([]byte(expected), &expJSON) + require.NoError(t, err, "Failed to unmarshal expected JSON") + + err = json.Unmarshal(actual, &actJSON) + require.NoError(t, err, "Failed to unmarshal actual JSON") + + assert.Equal(t, expJSON, actJSON, msgAndArgs...) +} + +// receiveTestMessages handles the logic of listening for server messages with a timeout +func receiveTestMessages(t *testing.T, h mocker) []json.RawMessage { + timeout := time.After(1 * time.Second) + + var receivedMessages []json.RawMessage + + for { + select { + case actualMsg, ok := <-h.conn.out: + if !ok { + // Channel is closed + // Return the collected messages. + return receivedMessages + } + receivedMessages = append(receivedMessages, actualMsg) + case <-timeout: + t.Fatalf("timed out waiting for server messages") + return nil + } + } +}