99 "google.golang.org/grpc"
1010
1111 "github.com/ydb-platform/ydb-go-sdk/v3/config"
12+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster"
1213 balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
1314 "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
1415 "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -26,8 +27,6 @@ import (
2627 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
2728)
2829
29- var ErrNoEndpoints = xerrors .Wrap (fmt .Errorf ("no endpoints" ))
30-
3130type discoveryClient interface {
3231 closer.Closer
3332
@@ -40,9 +39,12 @@ type Balancer struct {
4039 pool * conn.Pool
4140 discoveryClient discoveryClient
4241 discoveryRepeater repeater.Repeater
43- localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
4442
45- connectionsState atomic.Pointer [connectionsState ]
43+ cluster atomic.Pointer [cluster.Cluster ]
44+ conns xsync.Map [endpoint.Endpoint , conn.Conn ]
45+ banned xsync.Set [endpoint.Endpoint ]
46+
47+ localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
4648
4749 mu xsync.RWMutex
4850 onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
@@ -124,19 +126,48 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
124126}
125127
126128func (b * Balancer ) applyDiscoveredEndpoints (ctx context.Context , newest []endpoint.Endpoint , localDC string ) {
127- var (
128- onDone = trace .DriverOnBalancerUpdate (
129- b .driverConfig .Trace (), & ctx ,
130- stack .FunctionID (
131- "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
132- b .config .DetectLocalDC ,
133- )
134- previous = b .connections ().All ()
129+ onDone := trace .DriverOnBalancerUpdate (
130+ b .driverConfig .Trace (), & ctx ,
131+ stack .FunctionID (
132+ "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
133+ b .config .DetectLocalDC ,
135134 )
135+
136+ state := cluster .New (newest ,
137+ cluster .WithFilter (func (e endpoint.Info ) bool {
138+ if b .config .Filter == nil {
139+ return true
140+ }
141+
142+ return b .config .Filter .Allow (balancerConfig.Info {SelfLocation : localDC }, e )
143+ }),
144+ cluster .WithFallback (b .config .AllowFallback ),
145+ )
146+
147+ previous := b .cluster .Swap (state )
148+
149+ _ , added , dropped := xslices .Diff (previous .All (), newest , func (lhs , rhs endpoint.Endpoint ) int {
150+ return strings .Compare (lhs .Address (), rhs .Address ())
151+ })
152+
153+ for _ , e := range dropped {
154+ c , ok := b .conns .Extract (e )
155+ if ! ok {
156+ panic ("wrong balancer state" )
157+ }
158+ b .pool .Put (ctx , c )
159+ }
160+
161+ for _ , e := range added {
162+ cc , err := b .pool .Get (ctx , e )
163+ if err != nil {
164+ b .banned .Add (e )
165+ } else {
166+ b .conns .Set (e , cc )
167+ }
168+ }
169+
136170 defer func () {
137- _ , added , dropped := xslices .Diff (previous , newest , func (lhs , rhs endpoint.Endpoint ) int {
138- return strings .Compare (lhs .Address (), rhs .Address ())
139- })
140171 onDone (
141172 xslices .Transform (newest , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
142173 xslices .Transform (added , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
@@ -145,25 +176,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
145176 )
146177 }()
147178
148- connections := endpointsToConnections (b .pool , newest )
149- for _ , c := range connections {
150- b .pool .Allow (ctx , c )
151- c .Endpoint ().Touch ()
152- }
153-
154- info := balancerConfig.Info {SelfLocation : localDC }
155- state := newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
156-
157- endpointsInfo := make ([]endpoint.Info , len (newest ))
158- for i , e := range newest {
159- endpointsInfo [i ] = e
160- }
161-
162- b .connectionsState .Store (state )
179+ endpoints := xslices .Transform (newest , func (e endpoint.Endpoint ) endpoint.Info {
180+ return e
181+ })
163182
164183 b .mu .WithLock (func () {
165184 for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
166- onApplyDiscoveredEndpoints (ctx , endpointsInfo )
185+ onApplyDiscoveredEndpoints (ctx , endpoints )
167186 }
168187 })
169188}
@@ -212,18 +231,20 @@ func New(
212231 onDone (finalErr )
213232 }()
214233
234+ cc , err := pool .Get (ctx , endpoint .New (driverConfig .Endpoint ()))
235+ if err != nil {
236+ return nil , xerrors .WithStackTrace (err )
237+ }
238+
215239 b = & Balancer {
216- driverConfig : driverConfig ,
217- pool : pool ,
218- discoveryClient : internalDiscovery .New (ctx , pool .Get (
219- endpoint .New (driverConfig .Endpoint ()),
220- ), discoveryConfig ),
240+ config : balancerConfig.Config {},
241+ driverConfig : driverConfig ,
242+ pool : pool ,
243+ discoveryClient : internalDiscovery .New (ctx , cc , discoveryConfig ),
221244 localDCDetector : detectLocalDC ,
222245 }
223246
224- if config := driverConfig .Balancer (); config == nil {
225- b .config = balancerConfig.Config {}
226- } else {
247+ if config := driverConfig .Balancer (); config != nil {
227248 b .config = * config
228249 }
229250
@@ -289,10 +310,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
289310 defer func () {
290311 if err == nil {
291312 if cc .GetState () == conn .Banned {
292- b .pool . Allow ( ctx , cc )
313+ b .banned . Remove ( cc . Endpoint () )
293314 }
294- } else if xerrors . MustPessimizeEndpoint (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
295- b .pool . Ban ( ctx , cc , err )
315+ } else if conn . IsBadConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
316+ b .banned . Add ( cc . Endpoint () )
296317 }
297318 }()
298319
@@ -319,53 +340,46 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
319340 return nil
320341}
321342
322- func (b * Balancer ) connections () * connectionsState {
323- return b .connectionsState .Load ()
324- }
325-
326343func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
327- onDone := trace .DriverOnBalancerChooseEndpoint (
328- b .driverConfig .Trace (), & ctx ,
329- stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn" ),
344+ var (
345+ onDone = trace .DriverOnBalancerChooseEndpoint (
346+ b .driverConfig .Trace (), & ctx ,
347+ stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn" ),
348+ )
349+ state = b .cluster .Load ()
330350 )
351+
331352 defer func () {
353+ b .cluster .Store (state )
354+
355+ if b .discoveryRepeater != nil {
356+ b .discoveryRepeater .Force ()
357+ }
358+
332359 if err == nil {
333360 onDone (c .Endpoint (), nil )
334361 } else {
335362 onDone (nil , err )
336363 }
337364 }()
338365
339- if err = ctx .Err (); err != nil {
340- return nil , xerrors .WithStackTrace (err )
341- }
342-
343- var (
344- state = b .connections ()
345- failedCount int
346- )
347-
348- defer func () {
349- if failedCount * 2 > state .PreferredCount () && b .discoveryRepeater != nil {
350- b .discoveryRepeater .Force ()
366+ for attempts := 1 ; ; attempts ++ {
367+ if err = ctx .Err (); err != nil {
368+ return nil , xerrors .WithStackTrace (err )
351369 }
352- }()
353370
354- c , failedCount = state .GetConnection (ctx )
355- if c = = nil {
356- return nil , xerrors .WithStackTrace (
357- fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
358- )
359- }
371+ e , err : = state .Next (ctx )
372+ if err ! = nil {
373+ return nil , xerrors .WithStackTrace (
374+ fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , cluster . ErrNoEndpoints , attempts ),
375+ )
376+ }
360377
361- return c , nil
362- }
378+ cc , err := b .pool .Get (ctx , e )
379+ if err == nil {
380+ return cc , nil
381+ }
363382
364- func endpointsToConnections (p * conn.Pool , endpoints []endpoint.Endpoint ) []conn.Conn {
365- conns := make ([]conn.Conn , 0 , len (endpoints ))
366- for _ , e := range endpoints {
367- conns = append (conns , p .Get (e ))
383+ b .banned .Add (e )
368384 }
369-
370- return conns
371385}
0 commit comments