diff --git a/wsproxy/websocket_proxy.go b/wsproxy/websocket_proxy.go index 7092162..601b6be 100644 --- a/wsproxy/websocket_proxy.go +++ b/wsproxy/websocket_proxy.go @@ -261,28 +261,39 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { } } }() - // ping write loop - if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { - go func() { + + // write loop -- take bytes from http and write to websocket + dataWriteChan := make(chan []byte, 32) + go func() { + var pingChan <-chan time.Time + if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { ticker := time.NewTicker(p.pingInterval) + pingChan = ticker.C defer func() { ticker.Stop() conn.Close() }() - for { - select { - case <-ctx.Done(): - p.logger.Debugln("ping loop done") + } else { + pingChan = make(chan time.Time) + } + for { + select { + case <-ctx.Done(): + p.logger.Debugln("write loop done") + return + case <-pingChan: + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + case data := <-dataWriteChan: + if err = conn.WriteMessage(websocket.TextMessage, data); err != nil { + p.logger.Warnln("[write] error writing websocket message:", err) return - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(p.pingWait)) - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return - } } } - }() - } + } + }() + // write loop -- take messages from response and write to websocket scanner := bufio.NewScanner(responseBodyR) @@ -299,10 +310,7 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { continue } p.logger.Debugln("[write] scanned", scanner.Text()) - if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil { - p.logger.Warnln("[write] error writing websocket message:", err) - return - } + dataWriteChan <- scanner.Bytes() } if err := scanner.Err(); err != nil { p.logger.Warnln("scanner err:", err)