diff --git a/peer.go b/peer.go index 88f35c5..ebf9f07 100644 --- a/peer.go +++ b/peer.go @@ -1,6 +1,7 @@ package turnpike import ( + "context" "fmt" "time" ) @@ -24,6 +25,12 @@ type Peer interface { // Receive returns a channel of messages coming from the peer. Receive() <-chan Message + + //AddIncomeMiddleware implements preprocess income messages + AddIncomeMiddleware(f func(Message) (Message, error)) + + //GetContext returns context + GetContext() context.Context } // GetMessageTimeout is a convenience function to get a single message from a diff --git a/session.go b/session.go index 42564c4..57c14c5 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package turnpike import ( + "context" "fmt" ) @@ -27,10 +28,12 @@ func localPipe() (*localPeer, *localPeer) { a := &localPeer{ incoming: bToA, outgoing: aToB, + ctx: context.TODO(), } b := &localPeer{ incoming: aToB, outgoing: bToA, + ctx: context.TODO(), } return a, b @@ -39,6 +42,7 @@ func localPipe() (*localPeer, *localPeer) { type localPeer struct { outgoing chan<- Message incoming <-chan Message + ctx context.Context } func (s *localPeer) Receive() <-chan Message { @@ -54,3 +58,12 @@ func (s *localPeer) Close() error { close(s.outgoing) return nil } + +//todo +func (s *localPeer) AddIncomeMiddleware(f func(Message) (Message, error)) { + +} + +func (s *localPeer) GetContext() context.Context { + return s.ctx +} diff --git a/websocket.go b/websocket.go index 0ae460a..cc28373 100644 --- a/websocket.go +++ b/websocket.go @@ -1,6 +1,7 @@ package turnpike import ( + "context" "crypto/tls" "fmt" "net/http" @@ -11,12 +12,14 @@ import ( ) type websocketPeer struct { - conn *websocket.Conn - serializer Serializer - messages chan Message - payloadType int - closed bool - sendMutex sync.Mutex + incomeMiddleware func(Message) (Message, error) + conn *websocket.Conn + serializer Serializer + messages chan Message + payloadType int + closed bool + sendMutex sync.Mutex + ctx context.Context } func NewWebsocketPeer(serialization Serialization, url string, requestHeader http.Header, tlscfg *tls.Config, dial DialFunc) (Peer, error) { @@ -46,10 +49,12 @@ func newWebsocketPeer(url string, reqHeader http.Header, protocol string, serial return nil, err } ep := &websocketPeer{ - conn: conn, - messages: make(chan Message, 10), - serializer: serializer, - payloadType: payloadType, + conn: conn, + messages: make(chan Message, 10), + serializer: serializer, + payloadType: payloadType, + incomeMiddleware: nil, + ctx: context.TODO(), } go ep.run() @@ -79,6 +84,18 @@ func (ep *websocketPeer) Close() error { return ep.conn.Close() } +//AddIncomeMiddleware implements preprocess income messages +func (ep *websocketPeer) AddIncomeMiddleware(f func(Message) (Message, error)) { + ep.sendMutex.Lock() + ep.incomeMiddleware = f + ep.sendMutex.Unlock() +} + +//GetContext returns context +func (ep *websocketPeer) GetContext() context.Context { + return ep.ctx +} + func (ep *websocketPeer) run() { for { // TODO: use conn.NextMessage() and stream @@ -102,7 +119,17 @@ func (ep *websocketPeer) run() { log.Println("error deserializing peer message:", err) // TODO: handle error } else { - ep.messages <- msg + if ep.incomeMiddleware == nil { + ep.messages <- msg + } else { + m, err := ep.incomeMiddleware(msg) + if err != nil { + log.Println(err) + } else { + ep.messages <- m + } + } + } } } diff --git a/websocket_server.go b/websocket_server.go index 14849a9..d5d748c 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -1,6 +1,7 @@ package turnpike import ( + "context" "fmt" "net/http" @@ -108,10 +109,10 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - s.handleWebsocket(conn) + s.handleWebsocket(conn, r.Context()) } -func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) { +func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn, ctx context.Context) { var serializer Serializer var payloadType int if proto, ok := s.protocols[conn.Subprotocol()]; ok { @@ -139,6 +140,7 @@ func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) { serializer: serializer, messages: make(chan Message, 10), payloadType: payloadType, + ctx: ctx, } go peer.run()