Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvK
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
110 changes: 75 additions & 35 deletions graphqlws/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@ package graphqlws

import (
"context"
"fmt"
"net/http"
"time"

"github.com/gorilla/websocket"
"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/transport"

"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection"
"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/gql"
)

// ProtocolGraphQLWS is websocket subprotocol ID for GraphQL over WebSocket
// see https://github.com/apollographql/subscriptions-transport-ws
const ProtocolGraphQLWS = "graphql-ws"
const (
// ProtocolGraphQLTransportWS is the modern websocket subprotocol ID for GraphQL over WebSocket.
// see https://github.com/enisdenjo/graphql-ws
ProtocolGraphQLTransportWS = "graphql-transport-ws"
// ProtocolGraphQLWS is the deprecated websocket subprotocol ID for GraphQL over WebSocket.
// see https://github.com/apollographql/subscriptions-transport-ws
ProtocolGraphQLWS = "graphql-ws"
)

var defaultUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{ProtocolGraphQLWS},
Subprotocols: []string{ProtocolGraphQLTransportWS, ProtocolGraphQLWS},
}

type handler struct {
Expand Down Expand Up @@ -55,7 +61,38 @@ type Option interface {

type options struct {
contextGenerators []ContextGenerator
connectOptions []connection.Option
readLimit int64
writeTimeout time.Duration
hasReadLimit bool
hasWriteTimeout bool
}

func (o *options) connectionOptions() []connection.Option {
var connOptions []connection.Option

if o.hasReadLimit {
connOptions = append(connOptions, connection.ReadLimit(o.readLimit))
}

if o.hasWriteTimeout {
connOptions = append(connOptions, connection.WriteTimeout(o.writeTimeout))
}

return connOptions
}

func (o *options) transportOptions() []transport.Option {
var connOptions []transport.Option

if o.hasReadLimit {
connOptions = append(connOptions, transport.ReadLimit(o.readLimit))
}

if o.hasWriteTimeout {
connOptions = append(connOptions, transport.WriteTimeout(o.writeTimeout))
}

return connOptions
}

type optionFunc func(*options)
Expand All @@ -75,16 +112,16 @@ func WithContextGenerator(f ContextGenerator) Option {
// WithReadLimit limits the maximum size of incoming messages
func WithReadLimit(limit int64) Option {
return optionFunc(func(o *options) {
connOpt := connection.ReadLimit(limit)
o.connectOptions = append(o.connectOptions, connOpt)
o.readLimit = limit
o.hasReadLimit = true
})
}

// WithWriteTimeout sets a timeout for outgoing messages
func WithWriteTimeout(d time.Duration) Option {
return optionFunc(func(o *options) {
connOpt := connection.WriteTimeout(d)
o.connectOptions = append(o.connectOptions, connOpt)
o.writeTimeout = d
o.hasWriteTimeout = true
})
}

Expand All @@ -99,47 +136,50 @@ func applyOptions(opts ...Option) *options {
}

// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets
func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc {
func NewHandlerFunc(svc gql.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc {
handler := NewHandler()
return handler.NewHandlerFunc(svc, httpHandler, options...)
}

// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets
func (h *handler) NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc {
func (h *handler) NewHandlerFunc(svc gql.GraphQLService, httpHandler http.Handler, options ...Option) http.HandlerFunc {
o := applyOptions(options...)

return func(w http.ResponseWriter, r *http.Request) {
for _, subprotocol := range websocket.Subprotocols(r) {
if subprotocol != ProtocolGraphQLWS {
continue
}

ctx, err := buildContext(r, o.contextGenerators)
if err != nil {
w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error())
if !websocket.IsWebSocketUpgrade(r) {
if httpHandler == nil {
http.Error(w, "Not Found", http.StatusNotFound)
return
}

ws, err := h.Upgrader.Upgrade(w, r, nil)
if err != nil {
w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error())
return
}
httpHandler.ServeHTTP(w, r)
return
}

if ws.Subprotocol() != ProtocolGraphQLWS {
w.Header().Set("X-WebSocket-Upgrade-Failure",
fmt.Sprintf("upgraded websocket has wrong subprotocol (%s)", ws.Subprotocol()))
ws.Close()
return
}
ctx, err := buildContext(r, o.contextGenerators)
if err != nil {
w.Header().Set("X-WebSocket-Upgrade-Failure", err.Error())
return
}

go connection.Connect(ctx, ws, svc, o.connectOptions...)
ws, err := h.Upgrader.Upgrade(w, r, nil)
if err != nil {
// UPGRADE FAILED: The Upgrader has already written an error response.
// Do not call the httpHandler
return
}

// Fallback to HTTP
w.Header().Set("X-WebSocket-Upgrade-Failure", "no subprotocols available")
httpHandler.ServeHTTP(w, r)
switch ws.Subprotocol() {
case ProtocolGraphQLWS:
go connection.Connect(ctx, ws, svc, o.connectionOptions()...)

case ProtocolGraphQLTransportWS:
go transport.Connect(ctx, ws, svc, o.transportOptions()...)

default:
w.Header().Set("X-WebSocket-Upgrade-Failure", "unsupported subprotocol")
ws.Close()
}
}
}

Expand Down
Loading