Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/golang-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
raceFlag: "-race -v"
runs-on: ubuntu-22.04
steps:
- name: Install Go
Expand Down Expand Up @@ -258,6 +258,7 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-tags devcert \
-exec 'sudo' \
-timeout 10m ./relay/... ./shared/relay/...

Expand Down
4 changes: 3 additions & 1 deletion client/internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}

relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), &relayClient.ManagerOpts{
MTU: engineConfig.MTU,
})
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
if token != nil {
Expand Down
16 changes: 9 additions & 7 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -25,7 +24,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"

"github.com/netbirdio/netbird/client/internal/stdnet"

"github.com/netbirdio/management-integrations/integrations"

"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
Expand Down Expand Up @@ -227,7 +229,7 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
Expand Down Expand Up @@ -373,7 +375,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
Expand Down Expand Up @@ -600,7 +602,7 @@ func TestEngine_Sync(t *testing.T) {
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
Expand Down Expand Up @@ -765,7 +767,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)

relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
Expand Down Expand Up @@ -967,7 +969,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)

relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
Expand Down Expand Up @@ -1499,7 +1501,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
MTU: iface.DefaultMTU,
}

relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
Expand Down
1 change: 1 addition & 0 deletions relay/server/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func (r *Relay) Accept(conn net.Conn) {
storeTime := time.Now()
if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
r.notifier.PeerWentOffline(peer.ID())
}
r.notifier.PeerCameOnline(peer.ID())

Expand Down
123 changes: 111 additions & 12 deletions shared/relay/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"go.opentelemetry.io/otel"

"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
"github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/messages"
"github.com/netbirdio/netbird/util"

"github.com/netbirdio/netbird/relay/server"
)

var (
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
t.Fatalf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn(ctx, "bob")
if err == nil {
Expand Down Expand Up @@ -364,7 +364,7 @@ func TestBindReconnect(t *testing.T) {
clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
t.Fatalf("failed to connect to server: %s", err)
}

_, err = clientAlice.OpenConn(ctx, "bob")
Expand All @@ -374,7 +374,7 @@ func TestBindReconnect(t *testing.T) {

chBob, err := clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}

log.Infof("closing client Alice")
Expand All @@ -386,12 +386,12 @@ func TestBindReconnect(t *testing.T) {
clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
t.Fatalf("failed to connect to server: %s", err)
}

chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}

testString := "hello alice, I am bob"
Expand All @@ -402,7 +402,7 @@ func TestBindReconnect(t *testing.T) {

chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}

_, err = chBob.Write([]byte(testString))
Expand All @@ -427,6 +427,105 @@ func TestBindReconnect(t *testing.T) {
}
}

func TestBindReconnectRace(t *testing.T) {
ctx := context.Background()

srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()

defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()

// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}

clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()

// Run the reconnection scenario multiple times to expose the race
failures := 0
iterations := 1000

for i := 0; i < iterations; i++ {
log.Infof("Iteration %d/%d", i+1, iterations)

// Alice connects
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("iteration %d: failed to connect alice: %s", i, err)
}

// Bob opens connection to Alice
_, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Fatalf("iteration %d: failed to open conn from bob: %s", i, err)
}

// Close Alice immediately
err = clientAlice.Close()
if err != nil {
t.Errorf("iteration %d: failed to close alice: %s", i, err)
}

// Reconnect Alice immediately (this is where the race occurs)
clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("iteration %d: failed to reconnect alice: %s", i, err)
}

// Bob tries to open a new connection to the reconnected Alice
// Without the fix, this will sometimes fail with "connection already exists"
// because Bob still has the old connection in its map
_, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
log.Errorf("iteration %d: RACE DETECTED - failed to open new conn after reconnect: %s", i, err)
failures++
}

// Clean up
clientAlice.Close()

// Close Bob's connection to Alice to prepare for next iteration
clientBob.mu.Lock()
aliceID := messages.HashID("alice")
if container, ok := clientBob.conns[aliceID]; ok {
container.close()
delete(clientBob.conns, aliceID)
}
clientBob.mu.Unlock()
}
Comment on lines +511 to +519
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Encapsulation violation: Test accesses private client internals.

The test directly manipulates clientBob.mu and clientBob.conns, breaking encapsulation. This makes the test fragile and tightly coupled to implementation details.

Consider one of these approaches:

  1. Preferred: Add a public method to the Client type to close a specific connection by peer ID:
// In client.go
func (c *Client) CloseConnToPeer(peerID string) error {
    c.mu.Lock()
    defer c.mu.Unlock()
    hashedID := messages.HashID(peerID)
    if container, ok := c.conns[hashedID]; ok {
        container.close()
        delete(c.conns, hashedID)
    }
    return nil
}

Then use it in the test:

-		// Close Bob's connection to Alice to prepare for next iteration
-		clientBob.mu.Lock()
-		aliceID := messages.HashID("alice")
-		if container, ok := clientBob.conns[aliceID]; ok {
-			container.close()
-			delete(clientBob.conns, aliceID)
-		}
-		clientBob.mu.Unlock()
+		// Close Bob's connection to Alice to prepare for next iteration
+		err = clientBob.CloseConnToPeer("alice")
+		if err != nil {
+			t.Errorf("iteration %d: failed to close Bob's conn to Alice: %s", i, err)
+		}
  1. Alternative: Create a fresh Bob client for each iteration instead of manually cleaning state.
🤖 Prompt for AI Agents
shared/relay/client/client_test.go around lines 511-519: the test directly
accesses clientBob.mu and clientBob.conns which breaks encapsulation; add a
public method on Client (e.g., CloseConnToPeer(peerID string) error) in
client.go that hashes the peer ID, locks the client, finds the connection
container, calls its close method and deletes the map entry, then call that new
method from the test to close Bob’s connection to "alice" (alternatively, create
a fresh Bob client per iteration).


if failures > 0 {
t.Errorf("Race condition detected in %d out of %d iterations (%.1f%%)",
failures, iterations, float64(failures)/float64(iterations)*100)
} else {
log.Infof("No race detected in %d iterations (fix is working or race didn't trigger)", iterations)
}
}

func TestCloseConn(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -459,18 +558,18 @@ func TestCloseConn(t *testing.T) {
bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
err = bob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
t.Fatalf("failed to connect to server: %s", err)
}

clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
t.Fatalf("failed to connect to server: %s", err)
}

conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}

log.Infof("closing connection")
Expand Down Expand Up @@ -532,7 +631,7 @@ func TestCloseRelayConn(t *testing.T) {

conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}

_ = clientAlice.relayConn.Close()
Expand Down
Loading
Loading