Commit f78448e: Fix context cancel leak

kevburnsjr committed Jan 5, 2025 (1 parent f0d5f88)

Showing 6 changed files with 46 additions and 31 deletions.
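The leak: the removed raftCtx() helper in util.go built each request context with context.WithTimeout and discarded the returned CancelFunc, so the timer behind every context was only released when raftTimeout expired. This commit deletes the helper and creates the context at each call site with an explicit defer cancel(). Below is a minimal sketch of the before/after pattern; it is illustrative only (the raftTimeout value here is an assumption, and the helper names are stand-ins, not code copied from this commit).

```go
package main

import (
	"context"
	"time"
)

// Assumed value for illustration; the real raftTimeout is a package-level
// constant in this repository.
const raftTimeout = 3 * time.Second

// Old pattern (shape of the removed raftCtx helper): the CancelFunc is
// discarded, so the context's timer is only freed when the timeout fires.
// go vet's lostcancel check reports exactly this.
func leakyCtx() context.Context {
	ctx, _ := context.WithTimeout(context.Background(), raftTimeout)
	return ctx
}

// New pattern used throughout the commit: create the context at the call
// site and defer its cancellation so resources are released as soon as the
// call returns.
func callWithTimeout(do func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
	defer cancel()
	return do(ctx)
}

func main() {
	_ = leakyCtx()
	_ = callWithTimeout(func(ctx context.Context) error { return nil })
}
```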
agent.go: 23 additions & 8 deletions

@@ -460,7 +460,10 @@ func (a *Agent) resolvePeerGossipSeed() (gossip []string, err error) {
 			gossip = append(gossip, a.hostConfig.Gossip.AdvertiseAddress)
 			continue
 		}
-		res, err := a.grpcClientPool.get(peerApiAddr).Probe(raftCtx(), &internal.ProbeRequest{})
+		var res *internal.ProbeResponse
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		res, err = a.grpcClientPool.get(peerApiAddr).Probe(ctx, &internal.ProbeRequest{})
 		if err == nil && res != nil {
 			gossip = append(gossip, res.GossipAdvertiseAddress)
 		} else if err != nil && !strings.HasSuffix(err.Error(), `connect: connection refused"`) {
@@ -498,7 +501,9 @@ func (a *Agent) resolvePrimeMembership() (members map[uint64]string, init bool,
 			HostId: a.hostID(),
 		}
 	} else {
-		info, err = a.grpcClientPool.get(apiAddr).Info(raftCtx(), &internal.InfoRequest{})
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		info, err = a.grpcClientPool.get(apiAddr).Info(ctx, &internal.InfoRequest{})
 		if err != nil {
 			return
 		}
@@ -539,7 +544,9 @@ func (a *Agent) resolvePrimeMembership() (members map[uint64]string, init bool,
 			continue
 		}
 		if _, ok := uninitialized[apiAddr]; !ok {
-			res, err = a.grpcClientPool.get(apiAddr).Members(raftCtx(), &internal.MembersRequest{})
+			ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+			defer cancel()
+			res, err = a.grpcClientPool.get(apiAddr).Members(ctx, &internal.MembersRequest{})
 			if err != nil {
 				return
 			}
@@ -588,7 +595,9 @@ func (a *Agent) joinPrimeShard() (replicaID uint64, err error) {
 	a.log.Debugf("Joining prime shard")
 	var res *internal.JoinResponse
 	for _, peerApiAddr := range a.peers {
-		res, err = a.grpcClientPool.get(peerApiAddr).Join(raftCtx(), &internal.JoinRequest{
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		res, err = a.grpcClientPool.get(peerApiAddr).Join(ctx, &internal.JoinRequest{
 			HostId: a.hostID(),
 			IsNonVoting: a.replicaConfig.IsNonVoting,
 		})
@@ -686,20 +695,24 @@ func (a *Agent) joinPrimeReplica(hostID string, shardID uint64, isNonVoting bool
 }

 func (a *Agent) joinShardReplica(hostID string, shardID, replicaID uint64, isNonVoting bool) (res uint64, err error) {
-	m, err := a.host.SyncGetShardMembership(raftCtx(), shardID)
+	ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+	defer cancel()
+	m, err := a.host.SyncGetShardMembership(ctx, shardID)
 	if err != nil {
 		return
 	}
+	ctx, cancel = context.WithTimeout(context.Background(), raftTimeout)
+	defer cancel()
 	if isNonVoting {
 		if _, ok := m.NonVotings[replicaID]; ok {
 			return replicaID, nil
 		}
-		err = a.host.SyncRequestAddNonVoting(raftCtx(), shardID, replicaID, hostID, m.ConfigChangeID)
+		err = a.host.SyncRequestAddNonVoting(ctx, shardID, replicaID, hostID, m.ConfigChangeID)
 	} else {
 		if _, ok := m.Nodes[replicaID]; ok {
 			return replicaID, nil
 		}
-		err = a.host.SyncRequestAddReplica(raftCtx(), shardID, replicaID, hostID, m.ConfigChangeID)
+		err = a.host.SyncRequestAddReplica(ctx, shardID, replicaID, hostID, m.ConfigChangeID)
 	}
 	if err != nil {
 		return
@@ -724,7 +737,9 @@ func (a *Agent) parseMeta(nhid string) (apiAddr string, err error) {
 }

 func (a *Agent) primePropose(cmd []byte) (Result, error) {
-	return a.host.SyncPropose(raftCtx(), a.host.GetNoOPSession(a.replicaConfig.ShardID), cmd)
+	ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+	defer cancel()
+	return a.host.SyncPropose(ctx, a.host.GetNoOPSession(a.replicaConfig.ShardID), cmd)
 }

 // primeInit proposes addition of initial cluster state to prime shard
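One Go-semantics note on the pattern above, shown as an illustrative sketch rather than code from this commit: a defer placed inside a for loop (as in the peer-probing and join loops here) runs when the surrounding function returns, not at the end of each iteration, so one context per peer stays alive until the function exits. Wrapping each attempt in its own function scope releases it per iteration. The sketch extends the package from the example above (probeAll and probe are hypothetical names).

```go
// Scoping the context to an anonymous function makes defer cancel() run at
// the end of each loop iteration instead of at function return.
func probeAll(peers []string, probe func(context.Context, string) error) {
	for _, peer := range peers {
		func() {
			ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
			defer cancel() // runs when this anonymous function returns, i.e. per iteration
			_ = probe(ctx, peer)
		}()
	}
}
```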
agent_integration_test.go: 18 additions & 8 deletions

@@ -290,13 +290,19 @@ func runAgentSubTest(t *testing.T, agents []*Agent, shard Shard, op string, stal
 	require.NotNil(t, client)
 	if op == "update" && stale {
 		start := time.Now()
-		err = client.Commit(raftCtx(), shard.ID, bytes.Repeat([]byte("test"), i+1))
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		err = client.Commit(ctx, shard.ID, bytes.Repeat([]byte("test"), i+1))
 		update_time += time.Since(start)
 		updates++
 	} else if op == "update" && !stale {
-		val, _, err = client.Apply(raftCtx(), shard.ID, bytes.Repeat([]byte("test"), i+1))
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		val, _, err = client.Apply(ctx, shard.ID, bytes.Repeat([]byte("test"), i+1))
 	} else if op == "query" {
-		val, _, err = client.Read(raftCtx(), shard.ID, bytes.Repeat([]byte("test"), i+1), stale)
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		val, _, err = client.Read(ctx, shard.ID, bytes.Repeat([]byte("test"), i+1), stale)
 		assert.Nil(t, err)
 	} else if op == "watch" {
 		res := make(chan *Result)
@@ -317,7 +323,9 @@
 			}
 		}
 	}()
-	err = client.Watch(raftCtx(), shard.ID, bytes.Repeat([]byte("test"), i+1), res, stale)
+	ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+	defer cancel()
+	err = client.Watch(ctx, shard.ID, bytes.Repeat([]byte("test"), i+1), res, stale)
 	close(done)
 	wg.Wait()
 	assert.Equal(t, uint64((i+1)*4), n)
@@ -351,12 +359,14 @@ func runAgentSubTestByShard(t *testing.T, agents []*Agent, shard Shard, op strin
 		client = a.Client(shard.ID, WithWriteToLeader())
 	}
 	require.NotNil(t, client)
+	ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+	defer cancel()
 	if op == "update" && stale {
-		err = client.Commit(raftCtx(), bytes.Repeat([]byte("test"), i+1))
+		err = client.Commit(ctx, bytes.Repeat([]byte("test"), i+1))
 	} else if op == "update" && !stale {
-		val, _, err = client.Apply(raftCtx(), bytes.Repeat([]byte("test"), i+1))
+		val, _, err = client.Apply(ctx, bytes.Repeat([]byte("test"), i+1))
 	} else if op == "query" {
-		val, _, err = client.Read(raftCtx(), bytes.Repeat([]byte("test"), i+1), stale)
+		val, _, err = client.Read(ctx, bytes.Repeat([]byte("test"), i+1), stale)
 	} else if op == "watch" {
 		res := make(chan *Result)
 		done := make(chan bool)
@@ -376,7 +386,7 @@ func runAgentSubTestByShard(t *testing.T, agents []*Agent, shard Shard, op strin
 			}
 		}
 	}()
-	err = client.Watch(raftCtx(), bytes.Repeat([]byte("test"), i+1), res, stale)
+	err = client.Watch(ctx, bytes.Repeat([]byte("test"), i+1), res, stale)
 	close(done)
 	wg.Wait()
 	assert.Equal(t, uint64((i+1)*4), n)
go.mod: 0 additions & 1 deletion

@@ -54,7 +54,6 @@ require (
 	golang.org/x/net v0.33.0 // indirect
 	golang.org/x/sys v0.28.0 // indirect
 	golang.org/x/text v0.21.0 // indirect
-	google.golang.org/genproto v0.0.0-20241219192143-6b3ec007d9bb // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20241219192143-6b3ec007d9bb // indirect
 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
go.sum: 0 additions & 2 deletions

@@ -469,8 +469,6 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
 google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
 google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24=
-google.golang.org/genproto v0.0.0-20241219192143-6b3ec007d9bb h1:JGs+s1Q6osip3cDY197L1HmkuPn8wPp9Hfy9jl+Uz+U=
-google.golang.org/genproto v0.0.0-20241219192143-6b3ec007d9bb/go.mod h1:o8GgNarfULyZPNaIY8RDfXM7AZcmcKC/tbMWp/ZOFDw=
 google.golang.org/genproto/googleapis/rpc v0.0.0-20241219192143-6b3ec007d9bb h1:3oy2tynMOP1QbTC0MsNNAV+Se8M2Bd0A5+x1QHyw+pI=
 google.golang.org/genproto/googleapis/rpc v0.0.0-20241219192143-6b3ec007d9bb/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA=
 google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
host_controller.go: 3 additions & 1 deletion

@@ -196,7 +196,9 @@ func (c *hostController) requestShardJoin(members map[uint64]string, shardID, re
 			c.agent.log.Warningf(`Host not found %s`, hostID)
 			continue
 		}
-		res, err = c.agent.grpcClientPool.get(host.ApiAddress).Add(raftCtx(), &internal.AddRequest{
+		ctx, cancel := context.WithTimeout(context.Background(), raftTimeout)
+		defer cancel()
+		res, err = c.agent.grpcClientPool.get(host.ApiAddress).Add(ctx, &internal.AddRequest{
 			HostId: c.agent.hostID(),
 			ShardId: shardID,
 			ReplicaId: replicaID,
util.go: 2 additions & 11 deletions

@@ -46,14 +46,14 @@ var (
 	}
 	DefaultReplicaConfig = ReplicaConfig{
 		CheckQuorum: true,
-		CompactionOverhead: 1000,
+		CompactionOverhead: 10000,
 		ElectionRTT: 100,
 		EntryCompressionType: config.Snappy,
 		HeartbeatRTT: 10,
 		OrderedConfigChange: true,
 		Quiesce: false,
 		SnapshotCompressionType: config.Snappy,
-		SnapshotEntries: 1000,
+		SnapshotEntries: 10000,
 	}
 )

@@ -240,15 +240,6 @@ func base36Encode(id uint64) string {
 	return strconv.FormatUint(id, 36)
 }

-func raftCtx(ctxs ...context.Context) (ctx context.Context) {
-	ctx = context.Background()
-	if len(ctxs) > 0 {
-		ctx = ctxs[0]
-	}
-	ctx, _ = context.WithTimeout(ctx, raftTimeout)
-	return
-}
-
 type compositeRaftEventListener struct {
 	listeners []raftio.IRaftEventListener
 }
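For callers that still want a shared helper, the usual non-leaking shape is to return the CancelFunc alongside the context and let the caller defer it. A sketch of that alternative follows; it is not part of this commit (which inlines context.WithTimeout at every call site), the helper name is hypothetical, and it assumes the package-level raftTimeout constant.

```go
// Hypothetical non-leaking variant of the removed helper: the cancel func is
// returned instead of discarded, and the caller is responsible for deferring it.
func raftCtxWithCancel(parent ...context.Context) (context.Context, context.CancelFunc) {
	ctx := context.Background()
	if len(parent) > 0 {
		ctx = parent[0]
	}
	return context.WithTimeout(ctx, raftTimeout)
}

// Usage:
//	ctx, cancel := raftCtxWithCancel()
//	defer cancel()
```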
