Skip to content

Commit

Permalink
fix: race condition in hub
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztofdrys committed Apr 8, 2024
1 parent 7f7ea06 commit 2520156
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
28 changes: 28 additions & 0 deletions internal/websocket/common/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package common

import (
"bytes"
"sync"
"time"

"github.com/gorilla/websocket"

"github.com/twofas/2fas-server/internal/common/logging"
)

Expand Down Expand Up @@ -43,6 +45,8 @@ type Client struct {

// Buffered channel of outbound messages.
send chan []byte

sendMtx *sync.Mutex
}

// readPump pumps messages from the websocket connection to the hub.
Expand Down Expand Up @@ -133,3 +137,27 @@ func (c *Client) writePump() {
}
}
}

func (c *Client) sendMsg(bb []byte) bool {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()

if c.send == nil {
return false
}

c.send <- bb
return true
}

func (c *Client) close() {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()

if c.send == nil {
return
}

close(c.send)
c.send = nil
}
7 changes: 3 additions & 4 deletions internal/websocket/common/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) {
if !ok {
return
}
close(c.send)
c.close()
if h.isEmpty() {
h.onHubHasNoClients(h.id)
}
Expand All @@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) {
if !ok {
return
}
select {
case c.send <- msg:
default:
ok = c.sendMsg(msg)
if !ok {
h.unregisterClient(c)
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/websocket/common/hub_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client,
defer h.mtx.Unlock()

hub := h.getOrCreateHub(channel)
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}}
hub.registerClient(client)

// handler (caller of this method) isn't really interested in hub,
Expand Down
25 changes: 22 additions & 3 deletions internal/websocket/common/hub_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) {
hp := newHubPool()
const channelsNo = 100
const clientsPerChannel = 1000
const messagesSentToEachHub = 100

hubs := &sync.Map{}

Expand All @@ -58,21 +59,33 @@ func TestCreateRemoveConcurrently(t *testing.T) {
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
// Each of them will call `wg.Done() once and we can't progress until all of them are done.
wg.Add(channelsNo*clientsPerChannel + channelsNo)
// We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and
// wait for it to finish.
wg.Add(channelsNo * clientsPerChannel)

for i := 0; i < channelsNo; i++ {
channelID := fmt.Sprintf("channel-%d", i)

c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
go fakeReadPump(c.send, &wg)
go func() {
for i := 0; i < messagesSentToEachHub; i++ {
h.broadcastMsg([]byte("test"))
}
}()

go func() {
defer wg.Done()
for j := 0; j < clientsPerChannel; j++ {
c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
go fakeReadPump(c.send, &wg)

go func() {
h.unregisterClient(c)
wg.Done()
}()
}
_, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
}()
}
wg.Wait()
Expand All @@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) {
return true
})
}

func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
defer wg.Done()
for range c {
}
}

0 comments on commit 2520156

Please sign in to comment.