From 15d4f05df6a343ed2fb72f40651fd4c1be245e0b Mon Sep 17 00:00:00 2001 From: cbillett Date: Tue, 2 Apr 2019 09:27:37 -0400 Subject: [PATCH 1/2] added authentication support through new "Authentication" option --- graphqlws/http.go | 10 ++++-- graphqlws/internal/connection/connection.go | 38 +++++++++++++++++---- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/graphqlws/http.go b/graphqlws/http.go index 8351893..dd93e95 100644 --- a/graphqlws/http.go +++ b/graphqlws/http.go @@ -1,9 +1,8 @@ package graphqlws import ( - "net/http" - "github.com/gorilla/websocket" + "net/http" "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection" ) @@ -17,6 +16,11 @@ var upgrader = websocket.Upgrader{ // NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) http.HandlerFunc { + return NewHandlerFuncWithAuth(svc, httpHandler, nil) +} + +// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets +func NewHandlerFuncWithAuth(svc connection.GraphQLService, httpHandler http.Handler, authFunc connection.AuthenticateFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { for _, subprotocol := range websocket.Subprotocols(r) { if subprotocol == "graphql-ws" { @@ -30,7 +34,7 @@ func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) htt return } - go connection.Connect(ws, svc) + go connection.Connect(ws, svc, connection.Authentication(authFunc)) return } } diff --git a/graphqlws/internal/connection/connection.go b/graphqlws/internal/connection/connection.go index 94c9167..e9efee1 100644 --- a/graphqlws/internal/connection/connection.go +++ b/graphqlws/internal/connection/connection.go @@ -47,18 +47,26 @@ type startMessagePayload struct { Variables map[string]interface{} `json:"variables"` } -type initMessagePayload struct{} - // GraphQLService interface type GraphQLService interface { Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) } +type AuthenticateFunc func(ctx context.Context, payload map[string]interface{}) error + type connection struct { - cancel func() - service GraphQLService - writeTimeout time.Duration - ws wsConnection + cancel func() + service GraphQLService + writeTimeout time.Duration + ws wsConnection + authenticated bool + authenticateFunc AuthenticateFunc +} + +func Authentication(f AuthenticateFunc) func(conn *connection) { + return func(conn *connection) { + conn.authenticateFunc = f + } } // ReadLimit limits the maximum size of incoming messages @@ -159,15 +167,31 @@ func (conn *connection) readLoop(ctx context.Context, send sendFunc) { switch msg.Type { case typeConnectionInit: - var initMsg initMessagePayload + + var initMsg map[string]interface{} if err := json.Unmarshal(msg.Payload, &initMsg); err != nil { ep := errPayload(fmt.Errorf("invalid payload for type: %s", msg.Type)) send("", typeConnectionError, ep) continue } + + if conn.authenticateFunc != nil { + if err := conn.authenticateFunc(ctx, initMsg); err != nil { + send("", typeConnectionError, errPayload(err)) + continue + } + } + conn.authenticated = true send("", typeConnectionAck, nil) case typeStart: + + if !conn.authenticated && conn.authenticateFunc != nil { + ep := errPayload(errors.New("authentication required.")) + send("", typeConnectionError, ep) + continue + } + // TODO: check an operation with the same ID hasn't been started already if msg.ID == "" { ep := errPayload(errors.New("missing ID for start operation")) From 6cf57af6d464aa2a2f3e9c2d68ef174560d1bddf Mon Sep 17 00:00:00 2001 From: cbillett Date: Tue, 2 Apr 2019 10:13:33 -0400 Subject: [PATCH 2/2] added http request to authentication "Authentication" option --- graphqlws/http.go | 2 +- graphqlws/internal/connection/connection.go | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/graphqlws/http.go b/graphqlws/http.go index dd93e95..56b424d 100644 --- a/graphqlws/http.go +++ b/graphqlws/http.go @@ -34,7 +34,7 @@ func NewHandlerFuncWithAuth(svc connection.GraphQLService, httpHandler http.Hand return } - go connection.Connect(ws, svc, connection.Authentication(authFunc)) + go connection.Connect(ws, svc, connection.Authentication(r, authFunc)) return } } diff --git a/graphqlws/internal/connection/connection.go b/graphqlws/internal/connection/connection.go index e9efee1..84006f2 100644 --- a/graphqlws/internal/connection/connection.go +++ b/graphqlws/internal/connection/connection.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "time" ) @@ -52,7 +53,7 @@ type GraphQLService interface { Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) } -type AuthenticateFunc func(ctx context.Context, payload map[string]interface{}) error +type AuthenticateFunc func(ctx context.Context, r *http.Request, payload map[string]interface{}) error type connection struct { cancel func() @@ -61,11 +62,13 @@ type connection struct { ws wsConnection authenticated bool authenticateFunc AuthenticateFunc + request *http.Request } -func Authentication(f AuthenticateFunc) func(conn *connection) { +func Authentication(r *http.Request, f AuthenticateFunc) func(conn *connection) { return func(conn *connection) { conn.authenticateFunc = f + conn.request = r } } @@ -176,7 +179,7 @@ func (conn *connection) readLoop(ctx context.Context, send sendFunc) { } if conn.authenticateFunc != nil { - if err := conn.authenticateFunc(ctx, initMsg); err != nil { + if err := conn.authenticateFunc(ctx, conn.request, initMsg); err != nil { send("", typeConnectionError, errPayload(err)) continue }