Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
"time"
)
Expand All @@ -24,6 +25,12 @@ type Peer interface {

// Receive returns a channel of messages coming from the peer.
Receive() <-chan Message

//AddIncomeMiddleware implements preprocess income messages
AddIncomeMiddleware(f func(Message) (Message, error))

//GetContext returns context
GetContext() context.Context
}

// GetMessageTimeout is a convenience function to get a single message from a
Expand Down
13 changes: 13 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
)

Expand All @@ -27,10 +28,12 @@ func localPipe() (*localPeer, *localPeer) {
a := &localPeer{
incoming: bToA,
outgoing: aToB,
ctx: context.TODO(),
}
b := &localPeer{
incoming: aToB,
outgoing: bToA,
ctx: context.TODO(),
}

return a, b
Expand All @@ -39,6 +42,7 @@ func localPipe() (*localPeer, *localPeer) {
type localPeer struct {
outgoing chan<- Message
incoming <-chan Message
ctx context.Context
}

func (s *localPeer) Receive() <-chan Message {
Expand All @@ -54,3 +58,12 @@ func (s *localPeer) Close() error {
close(s.outgoing)
return nil
}

//todo
func (s *localPeer) AddIncomeMiddleware(f func(Message) (Message, error)) {

}

func (s *localPeer) GetContext() context.Context {
return s.ctx
}
49 changes: 38 additions & 11 deletions websocket.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"crypto/tls"
"fmt"
"net/http"
Expand All @@ -11,12 +12,14 @@ import (
)

type websocketPeer struct {
conn *websocket.Conn
serializer Serializer
messages chan Message
payloadType int
closed bool
sendMutex sync.Mutex
incomeMiddleware func(Message) (Message, error)
conn *websocket.Conn
serializer Serializer
messages chan Message
payloadType int
closed bool
sendMutex sync.Mutex
ctx context.Context
}

func NewWebsocketPeer(serialization Serialization, url string, requestHeader http.Header, tlscfg *tls.Config, dial DialFunc) (Peer, error) {
Expand Down Expand Up @@ -46,10 +49,12 @@ func newWebsocketPeer(url string, reqHeader http.Header, protocol string, serial
return nil, err
}
ep := &websocketPeer{
conn: conn,
messages: make(chan Message, 10),
serializer: serializer,
payloadType: payloadType,
conn: conn,
messages: make(chan Message, 10),
serializer: serializer,
payloadType: payloadType,
incomeMiddleware: nil,
ctx: context.TODO(),
}
go ep.run()

Expand Down Expand Up @@ -79,6 +84,18 @@ func (ep *websocketPeer) Close() error {
return ep.conn.Close()
}

//AddIncomeMiddleware implements preprocess income messages
func (ep *websocketPeer) AddIncomeMiddleware(f func(Message) (Message, error)) {
ep.sendMutex.Lock()
ep.incomeMiddleware = f
ep.sendMutex.Unlock()
}

//GetContext returns context
func (ep *websocketPeer) GetContext() context.Context {
return ep.ctx
}

func (ep *websocketPeer) run() {
for {
// TODO: use conn.NextMessage() and stream
Expand All @@ -102,7 +119,17 @@ func (ep *websocketPeer) run() {
log.Println("error deserializing peer message:", err)
// TODO: handle error
} else {
ep.messages <- msg
if ep.incomeMiddleware == nil {
ep.messages <- msg
} else {
m, err := ep.incomeMiddleware(msg)
if err != nil {
log.Println(err)
} else {
ep.messages <- m
}
}

}
}
}
Expand Down
6 changes: 4 additions & 2 deletions websocket_server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -108,10 +109,10 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.handleWebsocket(conn)
s.handleWebsocket(conn, r.Context())
}

func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) {
func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn, ctx context.Context) {
var serializer Serializer
var payloadType int
if proto, ok := s.protocols[conn.Subprotocol()]; ok {
Expand Down Expand Up @@ -139,6 +140,7 @@ func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) {
serializer: serializer,
messages: make(chan Message, 10),
payloadType: payloadType,
ctx: ctx,
}
go peer.run()

Expand Down