Skip to content

Commit cefa7d0

Browse files
committed
refactoring of internal/balancer
1 parent 91cf42d commit cefa7d0

12 files changed

+273
-963
lines changed

driver.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery"
2121
discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config"
2222
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
23-
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
2423
internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query"
2524
queryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config"
2625
internalRatelimiter "github.com/ydb-platform/ydb-go-sdk/v3/internal/ratelimiter"
@@ -488,7 +487,7 @@ func (d *Driver) connect(ctx context.Context) (err error) {
488487

489488
d.discovery = xsync.OnceValue(func() (*internalDiscovery.Client, error) {
490489
return internalDiscovery.New(xcontext.ValueOnly(ctx),
491-
d.pool.Get(endpoint.New(d.config.Endpoint())),
490+
d.balancer,
492491
discoveryConfig.New(
493492
append(
494493
// prepend common params from root config

internal/balancer/balancer.go

+88-74
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"google.golang.org/grpc"
1010

1111
"github.com/ydb-platform/ydb-go-sdk/v3/config"
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster"
1213
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
1314
"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
1415
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -26,8 +27,6 @@ import (
2627
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2728
)
2829

29-
var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints"))
30-
3130
type discoveryClient interface {
3231
closer.Closer
3332

@@ -40,9 +39,12 @@ type Balancer struct {
4039
pool *conn.Pool
4140
discoveryClient discoveryClient
4241
discoveryRepeater repeater.Repeater
43-
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4442

45-
connectionsState atomic.Pointer[connectionsState]
43+
cluster atomic.Pointer[cluster.Cluster]
44+
conns xsync.Map[endpoint.Endpoint, conn.Conn]
45+
banned xsync.Set[endpoint.Endpoint]
46+
47+
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4648

4749
mu xsync.RWMutex
4850
onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
@@ -124,19 +126,49 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
124126
}
125127

126128
func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) {
127-
var (
128-
onDone = trace.DriverOnBalancerUpdate(
129-
b.driverConfig.Trace(), &ctx,
130-
stack.FunctionID(
131-
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
132-
b.config.DetectLocalDC,
133-
)
134-
previous = b.connections().All()
129+
onDone := trace.DriverOnBalancerUpdate(
130+
b.driverConfig.Trace(), &ctx,
131+
stack.FunctionID(
132+
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
133+
b.config.DetectLocalDC,
135134
)
135+
136+
state := cluster.New(newest,
137+
cluster.From(b.cluster.Load()),
138+
cluster.WithFallback(b.config.AllowFallback),
139+
cluster.WithFilter(func(e endpoint.Info) bool {
140+
if b.config.Filter == nil {
141+
return true
142+
}
143+
144+
return b.config.Filter.Allow(balancerConfig.Info{SelfLocation: localDC}, e)
145+
}),
146+
)
147+
148+
previous := b.cluster.Swap(state)
149+
150+
_, added, dropped := xslices.Diff(previous.All(), newest, func(lhs, rhs endpoint.Endpoint) int {
151+
return strings.Compare(lhs.Address(), rhs.Address())
152+
})
153+
154+
for _, e := range dropped {
155+
c, ok := b.conns.Extract(e)
156+
if !ok {
157+
panic("wrong balancer state")
158+
}
159+
b.pool.Put(ctx, c)
160+
}
161+
162+
for _, e := range added {
163+
cc, err := b.pool.Get(ctx, e)
164+
if err != nil {
165+
b.banned.Add(e)
166+
} else {
167+
b.conns.Set(e, cc)
168+
}
169+
}
170+
136171
defer func() {
137-
_, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int {
138-
return strings.Compare(lhs.Address(), rhs.Address())
139-
})
140172
onDone(
141173
xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
142174
xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
@@ -145,25 +177,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
145177
)
146178
}()
147179

148-
connections := endpointsToConnections(b.pool, newest)
149-
for _, c := range connections {
150-
b.pool.Allow(ctx, c)
151-
c.Endpoint().Touch()
152-
}
153-
154-
info := balancerConfig.Info{SelfLocation: localDC}
155-
state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback)
156-
157-
endpointsInfo := make([]endpoint.Info, len(newest))
158-
for i, e := range newest {
159-
endpointsInfo[i] = e
160-
}
161-
162-
b.connectionsState.Store(state)
180+
endpoints := xslices.Transform(newest, func(e endpoint.Endpoint) endpoint.Info {
181+
return e
182+
})
163183

164184
b.mu.WithLock(func() {
165185
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
166-
onApplyDiscoveredEndpoints(ctx, endpointsInfo)
186+
onApplyDiscoveredEndpoints(ctx, endpoints)
167187
}
168188
})
169189
}
@@ -212,18 +232,20 @@ func New(
212232
onDone(finalErr)
213233
}()
214234

235+
cc, err := pool.Get(ctx, endpoint.New(driverConfig.Endpoint()))
236+
if err != nil {
237+
return nil, xerrors.WithStackTrace(err)
238+
}
239+
215240
b = &Balancer{
216-
driverConfig: driverConfig,
217-
pool: pool,
218-
discoveryClient: internalDiscovery.New(ctx, pool.Get(
219-
endpoint.New(driverConfig.Endpoint()),
220-
), discoveryConfig),
241+
config: balancerConfig.Config{},
242+
driverConfig: driverConfig,
243+
pool: pool,
244+
discoveryClient: internalDiscovery.New(ctx, cc, discoveryConfig),
221245
localDCDetector: detectLocalDC,
222246
}
223247

224-
if config := driverConfig.Balancer(); config == nil {
225-
b.config = balancerConfig.Config{}
226-
} else {
248+
if config := driverConfig.Balancer(); config != nil {
227249
b.config = *config
228250
}
229251

@@ -289,10 +311,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
289311
defer func() {
290312
if err == nil {
291313
if cc.GetState() == conn.Banned {
292-
b.pool.Allow(ctx, cc)
314+
b.banned.Remove(cc.Endpoint())
293315
}
294316
} else if conn.IsBadConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
295-
b.pool.Ban(ctx, cc, err)
317+
b.banned.Add(cc.Endpoint())
296318
}
297319
}()
298320

@@ -319,53 +341,45 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
319341
return nil
320342
}
321343

322-
func (b *Balancer) connections() *connectionsState {
323-
return b.connectionsState.Load()
324-
}
325-
326-
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
344+
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, finalErr error) {
327345
onDone := trace.DriverOnBalancerChooseEndpoint(
328346
b.driverConfig.Trace(), &ctx,
329347
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"),
330348
)
349+
331350
defer func() {
332-
if err == nil {
351+
if finalErr == nil {
333352
onDone(c.Endpoint(), nil)
334353
} else {
335-
onDone(nil, err)
354+
if b.cluster.Load().Availability() < 0.5 && b.discoveryRepeater != nil {
355+
b.discoveryRepeater.Force()
356+
}
357+
358+
onDone(nil, finalErr)
336359
}
337360
}()
338361

339-
if err = ctx.Err(); err != nil {
340-
return nil, xerrors.WithStackTrace(err)
341-
}
362+
for attempts := 1; ; attempts++ {
363+
if err := ctx.Err(); err != nil {
364+
return nil, xerrors.WithStackTrace(err)
365+
}
342366

343-
var (
344-
state = b.connections()
345-
failedCount int
346-
)
367+
state := b.cluster.Load()
347368

348-
defer func() {
349-
if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil {
350-
b.discoveryRepeater.Force()
369+
e, err := state.Next(ctx)
370+
if err != nil {
371+
return nil, xerrors.WithStackTrace(
372+
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", cluster.ErrNoEndpoints, attempts),
373+
)
351374
}
352-
}()
353-
354-
c, failedCount = state.GetConnection(ctx)
355-
if c == nil {
356-
return nil, xerrors.WithStackTrace(
357-
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
358-
)
359-
}
360375

361-
return c, nil
362-
}
376+
cc, err := b.pool.Get(ctx, e)
377+
if err == nil {
378+
return cc, nil
379+
}
363380

364-
func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
365-
conns := make([]conn.Conn, 0, len(endpoints))
366-
for _, e := range endpoints {
367-
conns = append(conns, p.Get(e))
381+
if b.cluster.CompareAndSwap(state, cluster.Without(b.cluster.Load(), e)) {
382+
b.banned.Add(e)
383+
}
368384
}
369-
370-
return conns
371385
}

0 commit comments

Comments
 (0)