diff --git a/CHANGELOG.md b/CHANGELOG.md index 248bf28d4..658249084 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Pinged new connections on discovery attempt, closed dropped ones, so `ydb_go_sdk_ydb_driver_conns` metric is correct + ## v3.108.3 * Fixed handling of zero values for DyNumber * Fixed the decimal yql slice bounds out of range diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index 861b66fc4..5e6592b1a 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -175,6 +175,34 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context, cc *grpc.ClientC return nil } +func buildConnectionsState(ctx context.Context, pool interface { + GetIfPresent(endpoint endpoint.Endpoint) conn.Conn + Allow(ctx context.Context, cc conn.Conn) + EndpointsToConnections(endpoints []endpoint.Endpoint) []conn.Conn +}, newest []endpoint.Endpoint, + dropped []endpoint.Endpoint, + config balancerConfig.Config, + selfLocation balancerConfig.Info, +) *connectionsState { + connections := pool.EndpointsToConnections(newest) + for _, c := range connections { + pool.Allow(ctx, c) + c.Endpoint().Touch() + _ = c.Ping(ctx) + } + + state := newConnectionsState(connections, config.Filter, selfLocation, config.AllowFallback) + + for _, e := range dropped { + c := pool.GetIfPresent(e) + if c != nil { + _ = c.Close(ctx) + } + } + + return state +} + func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) { var ( onDone = trace.DriverOnBalancerUpdate( @@ -186,10 +214,12 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi ) previous = b.connections().All() ) + + _, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int { + return strings.Compare(lhs.Address(), rhs.Address()) + }) + defer func() { - _, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int { - return strings.Compare(lhs.Address(), rhs.Address()) - }) onDone( xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), @@ -198,21 +228,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi ) }() - connections := endpointsToConnections(b.pool, newest) - for _, c := range connections { - b.pool.Allow(ctx, c) - c.Endpoint().Touch() - } - info := balancerConfig.Info{SelfLocation: localDC} - state := newConnectionsState(connections, b.balancerConfig.Filter, info, b.balancerConfig.AllowFallback) - - endpointsInfo := make([]endpoint.Info, len(newest)) - for i, e := range newest { - endpointsInfo[i] = e - } - - b.connectionsState.Store(state) + b.connectionsState.Store(buildConnectionsState(ctx, b.pool, newest, dropped, b.balancerConfig, info)) } func (b *Balancer) Close(ctx context.Context) (err error) { @@ -444,12 +461,3 @@ func (b *Balancer) nextConn(ctx context.Context) (c conn.Conn, err error) { return c, nil } - -func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { - conns := make([]conn.Conn, 0, len(endpoints)) - for _, e := range endpoints { - conns = append(conns, p.Get(e)) - } - - return conns -} diff --git a/internal/balancer/balancer_test.go b/internal/balancer/balancer_test.go new file mode 100644 index 000000000..7305987fd --- /dev/null +++ b/internal/balancer/balancer_test.go @@ -0,0 +1,124 @@ +package balancer + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" +) + +type fakePool struct { + connections map[string]*mock.Conn +} + +func (fp *fakePool) EndpointsToConnections(eps []endpoint.Endpoint) []conn.Conn { + var conns []conn.Conn + for _, ep := range eps { + if c, ok := fp.connections[ep.Address()]; ok { + conns = append(conns, c) + } + } + + return conns +} + +func (fp *fakePool) Allow(_ context.Context, _ conn.Conn) {} + +func (fp *fakePool) GetIfPresent(ep endpoint.Endpoint) conn.Conn { + if c, ok := fp.connections[ep.Address()]; ok { + return c + } + + return nil +} + +func TestBuildConnectionsState(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + newEndpoints []endpoint.Endpoint + oldEndpoints []endpoint.Endpoint + initialConns map[string]*mock.Conn + conf balancerConfig.Config + selfLoc balancerConfig.Info + expectPinged []string + expectClosed []string + }{ + { + name: "single new and old endpoint", + newEndpoints: []endpoint.Endpoint{&mock.Endpoint{AddrField: "127.0.0.1"}}, + oldEndpoints: []endpoint.Endpoint{&mock.Endpoint{AddrField: "127.0.0.2"}}, + initialConns: map[string]*mock.Conn{ + "127.0.0.1": { + AddrField: "127.0.0.1", + State: conn.Online, + }, + "127.0.0.2": { + AddrField: "127.0.0.2", + State: conn.Offline, + }, + }, + conf: balancerConfig.Config{ + AllowFallback: true, + DetectNearestDC: true, + }, + selfLoc: balancerConfig.Info{SelfLocation: "local"}, + expectPinged: []string{"127.0.0.1"}, + expectClosed: []string{"127.0.0.2"}, + }, + { + newEndpoints: []endpoint.Endpoint{&mock.Endpoint{AddrField: "a1"}, &mock.Endpoint{AddrField: "a2"}}, + oldEndpoints: []endpoint.Endpoint{&mock.Endpoint{AddrField: "a3"}}, + initialConns: map[string]*mock.Conn{ + "a1": { + AddrField: "a1", + LocationField: "local", + State: conn.Offline, + }, + "a2": { + AddrField: "a2", + State: conn.Offline, + }, + "a3": { + AddrField: "a3", + State: conn.Online, + }, + }, + conf: balancerConfig.Config{ + AllowFallback: true, + DetectNearestDC: true, + }, + selfLoc: balancerConfig.Info{SelfLocation: "local"}, + expectPinged: []string{"a1", "a2"}, + expectClosed: []string{"a3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fp := &fakePool{connections: make(map[string]*mock.Conn)} + for addr, c := range tt.initialConns { + fp.connections[addr] = c + } + + state := buildConnectionsState(ctx, fp, tt.newEndpoints, tt.oldEndpoints, tt.conf, tt.selfLoc) + assert.NotNil(t, state) + for _, addr := range tt.expectPinged { + c := fp.connections[addr] + assert.True(t, c.Pinged.Load(), "connection %s should be pinged", addr) + assert.True(t, c.State == conn.Online || c.PingErr != nil) + } + for _, addr := range tt.expectClosed { + c := fp.connections[addr] + assert.True(t, c.Closed.Load(), "connection %s should be closed", addr) + assert.True(t, c.State == conn.Offline, "connection %s should be offline", addr) + } + }) + } +} diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 8b049c18d..9869f283e 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" @@ -36,6 +37,7 @@ var ( type Conn interface { grpc.ClientConnInterface + closer.Closer Endpoint() endpoint.Endpoint diff --git a/internal/conn/pool.go b/internal/conn/pool.go index 18974bc5e..b1087555e 100644 --- a/internal/conn/pool.go +++ b/internal/conn/pool.go @@ -40,6 +40,20 @@ func (p *Pool) GrpcDialOptions() []grpc.DialOption { return p.dialOptions } +func (p *Pool) GetIfPresent(endpoint endpoint.Endpoint) Conn { + var ( + address = endpoint.Address() + cc *conn + has bool + ) + + if cc, has = p.conns.Get(address); has { + return cc + } + + return nil +} + func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { var ( address = endpoint.Address() @@ -252,3 +266,12 @@ func NewPool(ctx context.Context, config Config) *Pool { return p } + +func (p *Pool) EndpointsToConnections(endpoints []endpoint.Endpoint) []Conn { + conns := make([]Conn, 0, len(endpoints)) + for _, e := range endpoints { + conns = append(conns, p.Get(e)) + } + + return conns +} diff --git a/internal/mock/conn.go b/internal/mock/conn.go index 7a9fb99f7..cc2c5afd9 100644 --- a/internal/mock/conn.go +++ b/internal/mock/conn.go @@ -2,6 +2,7 @@ package mock import ( "context" + "sync/atomic" "time" "google.golang.org/grpc" @@ -17,6 +18,8 @@ type Conn struct { NodeIDField uint32 State conn.State LocalDCField bool + Pinged atomic.Bool + Closed atomic.Bool } func (c *Conn) Invoke( @@ -53,7 +56,19 @@ func (c *Conn) Park(ctx context.Context) (err error) { panic("not implemented in mock") } +func (c *Conn) Close(ctx context.Context) error { + c.Closed.Store(true) + c.SetState(ctx, conn.Offline) + + return nil +} + func (c *Conn) Ping(ctx context.Context) error { + c.Pinged.Store(true) + if c.PingErr == nil { + c.SetState(ctx, conn.Online) + } + return c.PingErr } @@ -116,7 +131,7 @@ func (e *Endpoint) LoadFactor() float32 { } func (e *Endpoint) OverrideHost() string { - panic("not implemented in mock") + return "" } func (e *Endpoint) String() string {