9
9
"google.golang.org/grpc"
10
10
11
11
"github.com/ydb-platform/ydb-go-sdk/v3/config"
12
+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster"
12
13
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
13
14
"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
14
15
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -26,8 +27,6 @@ import (
26
27
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
27
28
)
28
29
29
- var ErrNoEndpoints = xerrors .Wrap (fmt .Errorf ("no endpoints" ))
30
-
31
30
type discoveryClient interface {
32
31
closer.Closer
33
32
@@ -40,9 +39,12 @@ type Balancer struct {
40
39
pool * conn.Pool
41
40
discoveryClient discoveryClient
42
41
discoveryRepeater repeater.Repeater
43
- localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
44
42
45
- connectionsState atomic.Pointer [connectionsState ]
43
+ cluster atomic.Pointer [cluster.Cluster ]
44
+ conns xsync.Map [endpoint.Endpoint , conn.Conn ]
45
+ banned xsync.Set [endpoint.Endpoint ]
46
+
47
+ localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
46
48
47
49
mu xsync.RWMutex
48
50
onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
@@ -124,19 +126,49 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
124
126
}
125
127
126
128
func (b * Balancer ) applyDiscoveredEndpoints (ctx context.Context , newest []endpoint.Endpoint , localDC string ) {
127
- var (
128
- onDone = trace .DriverOnBalancerUpdate (
129
- b .driverConfig .Trace (), & ctx ,
130
- stack .FunctionID (
131
- "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
132
- b .config .DetectLocalDC ,
133
- )
134
- previous = b .connections ().All ()
129
+ onDone := trace .DriverOnBalancerUpdate (
130
+ b .driverConfig .Trace (), & ctx ,
131
+ stack .FunctionID (
132
+ "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
133
+ b .config .DetectLocalDC ,
135
134
)
135
+
136
+ state := cluster .New (newest ,
137
+ cluster .From (b .cluster .Load ()),
138
+ cluster .WithFallback (b .config .AllowFallback ),
139
+ cluster .WithFilter (func (e endpoint.Info ) bool {
140
+ if b .config .Filter == nil {
141
+ return true
142
+ }
143
+
144
+ return b .config .Filter .Allow (balancerConfig.Info {SelfLocation : localDC }, e )
145
+ }),
146
+ )
147
+
148
+ previous := b .cluster .Swap (state )
149
+
150
+ _ , added , dropped := xslices .Diff (previous .All (), newest , func (lhs , rhs endpoint.Endpoint ) int {
151
+ return strings .Compare (lhs .Address (), rhs .Address ())
152
+ })
153
+
154
+ for _ , e := range dropped {
155
+ c , ok := b .conns .Extract (e )
156
+ if ! ok {
157
+ panic ("wrong balancer state" )
158
+ }
159
+ b .pool .Put (ctx , c )
160
+ }
161
+
162
+ for _ , e := range added {
163
+ cc , err := b .pool .Get (ctx , e )
164
+ if err != nil {
165
+ b .banned .Add (e )
166
+ } else {
167
+ b .conns .Set (e , cc )
168
+ }
169
+ }
170
+
136
171
defer func () {
137
- _ , added , dropped := xslices .Diff (previous , newest , func (lhs , rhs endpoint.Endpoint ) int {
138
- return strings .Compare (lhs .Address (), rhs .Address ())
139
- })
140
172
onDone (
141
173
xslices .Transform (newest , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
142
174
xslices .Transform (added , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
@@ -145,25 +177,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
145
177
)
146
178
}()
147
179
148
- connections := endpointsToConnections (b .pool , newest )
149
- for _ , c := range connections {
150
- b .pool .Allow (ctx , c )
151
- c .Endpoint ().Touch ()
152
- }
153
-
154
- info := balancerConfig.Info {SelfLocation : localDC }
155
- state := newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
156
-
157
- endpointsInfo := make ([]endpoint.Info , len (newest ))
158
- for i , e := range newest {
159
- endpointsInfo [i ] = e
160
- }
161
-
162
- b .connectionsState .Store (state )
180
+ endpoints := xslices .Transform (newest , func (e endpoint.Endpoint ) endpoint.Info {
181
+ return e
182
+ })
163
183
164
184
b .mu .WithLock (func () {
165
185
for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
166
- onApplyDiscoveredEndpoints (ctx , endpointsInfo )
186
+ onApplyDiscoveredEndpoints (ctx , endpoints )
167
187
}
168
188
})
169
189
}
@@ -212,18 +232,20 @@ func New(
212
232
onDone (finalErr )
213
233
}()
214
234
235
+ cc , err := pool .Get (ctx , endpoint .New (driverConfig .Endpoint ()))
236
+ if err != nil {
237
+ return nil , xerrors .WithStackTrace (err )
238
+ }
239
+
215
240
b = & Balancer {
216
- driverConfig : driverConfig ,
217
- pool : pool ,
218
- discoveryClient : internalDiscovery .New (ctx , pool .Get (
219
- endpoint .New (driverConfig .Endpoint ()),
220
- ), discoveryConfig ),
241
+ config : balancerConfig.Config {},
242
+ driverConfig : driverConfig ,
243
+ pool : pool ,
244
+ discoveryClient : internalDiscovery .New (ctx , cc , discoveryConfig ),
221
245
localDCDetector : detectLocalDC ,
222
246
}
223
247
224
- if config := driverConfig .Balancer (); config == nil {
225
- b .config = balancerConfig.Config {}
226
- } else {
248
+ if config := driverConfig .Balancer (); config != nil {
227
249
b .config = * config
228
250
}
229
251
@@ -289,10 +311,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
289
311
defer func () {
290
312
if err == nil {
291
313
if cc .GetState () == conn .Banned {
292
- b .pool . Allow ( ctx , cc )
314
+ b .banned . Remove ( cc . Endpoint () )
293
315
}
294
316
} else if conn .IsBadConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
295
- b .pool . Ban ( ctx , cc , err )
317
+ b .banned . Add ( cc . Endpoint () )
296
318
}
297
319
}()
298
320
@@ -319,53 +341,45 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
319
341
return nil
320
342
}
321
343
322
- func (b * Balancer ) connections () * connectionsState {
323
- return b .connectionsState .Load ()
324
- }
325
-
326
- func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
344
+ func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , finalErr error ) {
327
345
onDone := trace .DriverOnBalancerChooseEndpoint (
328
346
b .driverConfig .Trace (), & ctx ,
329
347
stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn" ),
330
348
)
349
+
331
350
defer func () {
332
- if err == nil {
351
+ if finalErr == nil {
333
352
onDone (c .Endpoint (), nil )
334
353
} else {
335
- onDone (nil , err )
354
+ if b .cluster .Load ().Availability () < 0.5 && b .discoveryRepeater != nil {
355
+ b .discoveryRepeater .Force ()
356
+ }
357
+
358
+ onDone (nil , finalErr )
336
359
}
337
360
}()
338
361
339
- if err = ctx .Err (); err != nil {
340
- return nil , xerrors .WithStackTrace (err )
341
- }
362
+ for attempts := 1 ; ; attempts ++ {
363
+ if err := ctx .Err (); err != nil {
364
+ return nil , xerrors .WithStackTrace (err )
365
+ }
342
366
343
- var (
344
- state = b .connections ()
345
- failedCount int
346
- )
367
+ state := b .cluster .Load ()
347
368
348
- defer func () {
349
- if failedCount * 2 > state .PreferredCount () && b .discoveryRepeater != nil {
350
- b .discoveryRepeater .Force ()
369
+ e , err := state .Next (ctx )
370
+ if err != nil {
371
+ return nil , xerrors .WithStackTrace (
372
+ fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , cluster .ErrNoEndpoints , attempts ),
373
+ )
351
374
}
352
- }()
353
-
354
- c , failedCount = state .GetConnection (ctx )
355
- if c == nil {
356
- return nil , xerrors .WithStackTrace (
357
- fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
358
- )
359
- }
360
375
361
- return c , nil
362
- }
376
+ cc , err := b .pool .Get (ctx , e )
377
+ if err == nil {
378
+ return cc , nil
379
+ }
363
380
364
- func endpointsToConnections (p * conn.Pool , endpoints []endpoint.Endpoint ) []conn.Conn {
365
- conns := make ([]conn.Conn , 0 , len (endpoints ))
366
- for _ , e := range endpoints {
367
- conns = append (conns , p .Get (e ))
381
+ if b .cluster .CompareAndSwap (state , cluster .Without (b .cluster .Load (), e )) {
382
+ b .banned .Add (e )
383
+ }
368
384
}
369
-
370
- return conns
371
385
}
0 commit comments