From 2bbeb9c97141c0f16d99d407b97c2d2366adfd38 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 19 Feb 2025 18:09:08 +0800 Subject: [PATCH] client: support dynamic start/stop of the router client (#9082) ref tikv/pd#8690 Support dynamic start and stop of the router client. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/client.go | 12 +- client/inner_client.go | 60 ++- client/opt/option.go | 48 ++- client/opt/option_test.go | 123 ++++-- tests/integrations/client/client_test.go | 234 ----------- .../integrations/client/router_client_test.go | 368 ++++++++++++++++++ 6 files changed, 555 insertions(+), 290 deletions(-) create mode 100644 tests/integrations/client/router_client_test.go diff --git a/client/client.go b/client/client.go index e5f6442780b..998519908fe 100644 --- a/client/client.go +++ b/client/client.go @@ -456,6 +456,12 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") } c.inner.option.SetTSOClientRPCConcurrency(value) + case opt.EnableRouterClient: + enable, ok := value.(bool) + if !ok { + return errors.New("[pd] invalid value type for EnableRouterClient option, it should be bool") + } + c.inner.option.SetEnableRouterClient(enable) default: return errors.New("[pd] unsupported client option") } @@ -569,12 +575,6 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } -// EnableRouterClient enables the router client. -// This is only for test currently. -func (c *client) EnableRouterClient() { - c.inner.initRouterClient() -} - func (c *client) getRouterClient() *router.Cli { c.inner.RLock() defer c.inner.RUnlock() diff --git a/client/inner_client.go b/client/inner_client.go index 181ee2c9d52..fbc8227ccce 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -71,16 +71,69 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { return err } + // Check if the router client has been enabled. + if c.option.GetEnableRouterClient() { + c.enableRouterClient() + } + c.wg.Add(1) + go c.routerClientInitializer() + return nil } -func (c *innerClient) initRouterClient() { +func (c *innerClient) routerClientInitializer() { + log.Info("[pd] start router client initializer") + defer c.wg.Done() + for { + select { + case <-c.ctx.Done(): + log.Info("[pd] exit router client initializer") + return + case <-c.option.EnableRouterClientCh: + if c.option.GetEnableRouterClient() { + log.Info("[pd] notified to enable the router client") + c.enableRouterClient() + } else { + log.Info("[pd] notified to disable the router client") + c.disableRouterClient() + } + } + } +} + +func (c *innerClient) enableRouterClient() { + // Check if the router client has been enabled. + c.RLock() + if c.routerClient != nil { + c.RUnlock() + return + } + c.RUnlock() + // Create a new router client first before acquiring the lock. + routerClient := router.NewClient(c.ctx, c.serviceDiscovery, c.option) c.Lock() - defer c.Unlock() + // Double check if the router client has been enabled. if c.routerClient != nil { + // Release the lock and close the router client. + c.Unlock() + routerClient.Close() return } - c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) + c.routerClient = routerClient + c.Unlock() +} + +func (c *innerClient) disableRouterClient() { + c.Lock() + if c.routerClient == nil { + c.Unlock() + return + } + routerClient := c.routerClient + c.routerClient = nil + c.Unlock() + // Close the router client after the lock is released. + routerClient.Close() } func (c *innerClient) setServiceMode(newMode pdpb.ServiceMode) { @@ -214,6 +267,7 @@ func (c *innerClient) setup() error { // Create dispatchers c.createTokenDispatcher() + return nil } diff --git a/client/opt/option.go b/client/opt/option.go index 2aa9be8ae7f..2790c93b003 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -33,6 +33,7 @@ const ( defaultEnableTSOFollowerProxy = false defaultEnableFollowerHandle = false defaultTSOClientRPCConcurrency = 1 + defaultEnableRouterClient = false ) // DynamicOption is used to distinguish the dynamic option type. @@ -49,6 +50,9 @@ const ( EnableFollowerHandle // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. TSOClientRPCConcurrency + // EnableRouterClient is the router client option. + // It is stored as bool. + EnableRouterClient dynamicOptionCount ) @@ -70,6 +74,7 @@ type Option struct { dynamicOptions [dynamicOptionCount]atomic.Value EnableTSOFollowerProxyCh chan struct{} + EnableRouterClientCh chan struct{} } // NewOption creates a new PD client option with the default values set. @@ -78,6 +83,7 @@ func NewOption() *Option { Timeout: defaultPDTimeout, MaxRetryTimes: maxInitClusterRetries, EnableTSOFollowerProxyCh: make(chan struct{}, 1), + EnableRouterClientCh: make(chan struct{}, 1), InitMetrics: true, } @@ -85,6 +91,7 @@ func NewOption() *Option { co.dynamicOptions[EnableTSOFollowerProxy].Store(defaultEnableTSOFollowerProxy) co.dynamicOptions[EnableFollowerHandle].Store(defaultEnableFollowerHandle) co.dynamicOptions[TSOClientRPCConcurrency].Store(defaultTSOClientRPCConcurrency) + co.dynamicOptions[EnableRouterClient].Store(defaultEnableRouterClient) return co } @@ -94,19 +101,13 @@ func (o *Option) SetMaxTSOBatchWaitInterval(interval time.Duration) error { if interval < 0 || interval > 10*time.Millisecond { return errors.New("[pd] invalid max TSO batch wait interval, should be between 0 and 10ms") } - old := o.GetMaxTSOBatchWaitInterval() - if interval != old { - o.dynamicOptions[MaxTSOBatchWaitInterval].Store(interval) - } + o.dynamicOptions[MaxTSOBatchWaitInterval].CompareAndSwap(o.GetMaxTSOBatchWaitInterval(), interval) return nil } // SetEnableFollowerHandle set the Follower Handle option. func (o *Option) SetEnableFollowerHandle(enable bool) { - old := o.GetEnableFollowerHandle() - if enable != old { - o.dynamicOptions[EnableFollowerHandle].Store(enable) - } + o.dynamicOptions[EnableFollowerHandle].CompareAndSwap(!enable, enable) } // GetEnableFollowerHandle gets the Follower Handle enable option. @@ -121,9 +122,7 @@ func (o *Option) GetMaxTSOBatchWaitInterval() time.Duration { // SetEnableTSOFollowerProxy sets the TSO Follower Proxy option. func (o *Option) SetEnableTSOFollowerProxy(enable bool) { - old := o.GetEnableTSOFollowerProxy() - if enable != old { - o.dynamicOptions[EnableTSOFollowerProxy].Store(enable) + if o.dynamicOptions[EnableTSOFollowerProxy].CompareAndSwap(!enable, enable) { select { case o.EnableTSOFollowerProxyCh <- struct{}{}: default: @@ -138,10 +137,7 @@ func (o *Option) GetEnableTSOFollowerProxy() bool { // SetTSOClientRPCConcurrency sets the TSO client RPC concurrency option. func (o *Option) SetTSOClientRPCConcurrency(value int) { - old := o.GetTSOClientRPCConcurrency() - if value != old { - o.dynamicOptions[TSOClientRPCConcurrency].Store(value) - } + o.dynamicOptions[TSOClientRPCConcurrency].CompareAndSwap(o.GetTSOClientRPCConcurrency(), value) } // GetTSOClientRPCConcurrency gets the TSO client RPC concurrency option. @@ -149,6 +145,21 @@ func (o *Option) GetTSOClientRPCConcurrency() int { return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) } +// SetEnableRouterClient sets the router client option. +func (o *Option) SetEnableRouterClient(enable bool) { + if o.dynamicOptions[EnableRouterClient].CompareAndSwap(!enable, enable) { + select { + case o.EnableRouterClientCh <- struct{}{}: + default: + } + } +} + +// GetEnableRouterClient gets the router client option. +func (o *Option) GetEnableRouterClient() bool { + return o.dynamicOptions[EnableRouterClient].Load().(bool) +} + // ClientOption configures client. type ClientOption func(*Option) @@ -210,6 +221,13 @@ func WithBackoffer(bo *retry.Backoffer) ClientOption { } } +// WithEnableRouterClient configures the client with router client option. +func WithEnableRouterClient(enable bool) ClientOption { + return func(op *Option) { + op.SetEnableRouterClient(enable) + } +} + // GetStoreOp represents available options when getting stores. type GetStoreOp struct { ExcludeTombstone bool diff --git a/client/opt/option_test.go b/client/opt/option_test.go index fa0decbb3a1..a0ab5a7effc 100644 --- a/client/opt/option_test.go +++ b/client/opt/option_test.go @@ -26,42 +26,101 @@ import ( func TestDynamicOptionChange(t *testing.T) { re := require.New(t) o := NewOption() - // Check the default value setting. - re.Equal(defaultMaxTSOBatchWaitInterval, o.GetMaxTSOBatchWaitInterval()) - re.Equal(defaultEnableTSOFollowerProxy, o.GetEnableTSOFollowerProxy()) - re.Equal(defaultEnableFollowerHandle, o.GetEnableFollowerHandle()) - - // Check the invalid value setting. - re.Error(o.SetMaxTSOBatchWaitInterval(time.Second)) - re.Equal(defaultMaxTSOBatchWaitInterval, o.GetMaxTSOBatchWaitInterval()) - expectInterval := time.Millisecond - o.SetMaxTSOBatchWaitInterval(expectInterval) - re.Equal(expectInterval, o.GetMaxTSOBatchWaitInterval()) - expectInterval = time.Duration(float64(time.Millisecond) * 0.5) - o.SetMaxTSOBatchWaitInterval(expectInterval) - re.Equal(expectInterval, o.GetMaxTSOBatchWaitInterval()) - expectInterval = time.Duration(float64(time.Millisecond) * 1.5) - o.SetMaxTSOBatchWaitInterval(expectInterval) - re.Equal(expectInterval, o.GetMaxTSOBatchWaitInterval()) - - expectBool := true - o.SetEnableTSOFollowerProxy(expectBool) - // Check the value changing notification. - testutil.Eventually(re, func() bool { - <-o.EnableTSOFollowerProxyCh - return true - }) - re.Equal(expectBool, o.GetEnableTSOFollowerProxy()) - // Check whether any data will be sent to the channel. - // It will panic if the test fails. - close(o.EnableTSOFollowerProxyCh) - // Setting the same value should not notify the channel. + + // Test default values. + re.Equal(defaultMaxTSOBatchWaitInterval, o.GetMaxTSOBatchWaitInterval(), "default max TSO batch wait interval") + re.Equal(defaultEnableTSOFollowerProxy, o.GetEnableTSOFollowerProxy(), "default enable TSO follower proxy") + re.Equal(defaultEnableFollowerHandle, o.GetEnableFollowerHandle(), "default enable follower handle") + re.Equal(defaultTSOClientRPCConcurrency, o.GetTSOClientRPCConcurrency(), "default TSO client RPC concurrency") + re.Equal(defaultEnableRouterClient, o.GetEnableRouterClient(), "default enable router client") + + // Test invalid setting. + err := o.SetMaxTSOBatchWaitInterval(time.Second) + re.Error(err, "expect error for invalid high interval") + // Value remains unchanged. + re.Equal(defaultMaxTSOBatchWaitInterval, o.GetMaxTSOBatchWaitInterval(), "max TSO batch wait interval should not change to an invalid value") + + // Define a list of valid intervals. + validIntervals := []time.Duration{ + time.Millisecond, + time.Duration(float64(time.Millisecond) * 0.5), + time.Duration(float64(time.Millisecond) * 1.5), + 10 * time.Millisecond, + 0, + } + for _, interval := range validIntervals { + // Use a subtest for each valid interval. + err := o.SetMaxTSOBatchWaitInterval(interval) + re.NoError(err, "expected interval %v to be set without error", interval) + re.Equal(interval, o.GetMaxTSOBatchWaitInterval(), "max TSO batch wait interval should be updated to %v", interval) + } + + clearChannel(o.EnableTSOFollowerProxyCh) + + // Testing that the setting is effective and a notification is sent. + var expectBool bool + for _, expectBool = range []bool{true, false} { + o.SetEnableTSOFollowerProxy(expectBool) + testutil.Eventually(re, func() bool { + select { + case <-o.EnableTSOFollowerProxyCh: + default: + return false + } + return o.GetEnableTSOFollowerProxy() == expectBool + }) + } + + // Testing that setting the same value should not trigger a notification. o.SetEnableTSOFollowerProxy(expectBool) + ensureNoNotification(t, o.EnableTSOFollowerProxyCh) + // This option does not use a notification channel. expectBool = true o.SetEnableFollowerHandle(expectBool) - re.Equal(expectBool, o.GetEnableFollowerHandle()) + re.Equal(expectBool, o.GetEnableFollowerHandle(), "EnableFollowerHandle should be set to true") expectBool = false o.SetEnableFollowerHandle(expectBool) - re.Equal(expectBool, o.GetEnableFollowerHandle()) + re.Equal(expectBool, o.GetEnableFollowerHandle(), "EnableFollowerHandle should be set to false") + + expectInt := 10 + o.SetTSOClientRPCConcurrency(expectInt) + re.Equal(expectInt, o.GetTSOClientRPCConcurrency(), "TSOClientRPCConcurrency should update accordingly") + + clearChannel(o.EnableRouterClientCh) + + // Testing that the setting is effective and a notification is sent. + for _, expectBool = range []bool{true, false} { + o.SetEnableRouterClient(expectBool) + testutil.Eventually(re, func() bool { + select { + case <-o.EnableRouterClientCh: + default: + return false + } + return o.GetEnableRouterClient() == expectBool + }) + } + + // Testing that setting the same value should not trigger a notification. + o.SetEnableRouterClient(expectBool) + ensureNoNotification(t, o.EnableRouterClientCh) +} + +// clearChannel drains any pending events from the channel. +func clearChannel(ch chan struct{}) { + select { + case <-ch: + default: + } +} + +// ensureNoNotification checks that no notification is sent on the channel within a short timeout. +func ensureNoNotification(t *testing.T, ch chan struct{}) { + select { + case v := <-ch: + t.Fatalf("unexpected notification received: %v", v) + case <-time.After(100 * time.Millisecond): + // No notification received as expected. + } } diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index bacf2a72618..e4ee2be03e2 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -20,10 +20,8 @@ import ( "encoding/json" "fmt" "math" - "math/rand" "os" "path" - "reflect" "sort" "strconv" "strings" @@ -1110,118 +1108,6 @@ func (suite *clientTestSuite) SetupTest() { suite.grpcSvr.DirectlyGetRaftCluster().ResetRegionCache() } -func (suite *clientTestSuite) TestGetRegion() { - re := suite.Require() - regionID := regionIDAllocator.alloc() - region := &metapb.Region{ - Id: regionID, - RegionEpoch: &metapb.RegionEpoch{ - ConfVer: 1, - Version: 1, - }, - Peers: peers, - } - req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(), - Region: region, - Leader: peers[0], - } - err := suite.regionHeartbeat.Send(req) - re.NoError(err) - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegion(context.Background(), []byte("a")) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(region, r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) && - r.Buckets == nil - }) - breq := &pdpb.ReportBucketsRequest{ - Header: newHeader(), - Buckets: &metapb.Buckets{ - RegionId: regionID, - Version: 1, - Keys: [][]byte{[]byte("a"), []byte("z")}, - PeriodInMs: 2000, - Stats: &metapb.BucketStats{ - ReadBytes: []uint64{1}, - ReadKeys: []uint64{1}, - ReadQps: []uint64{1}, - WriteBytes: []uint64{1}, - WriteKeys: []uint64{1}, - WriteQps: []uint64{1}, - }, - }, - } - re.NoError(suite.reportBucket.Send(breq)) - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegion(context.Background(), []byte("a"), opt.WithBuckets()) - re.NoError(err) - if r == nil { - return false - } - return r.Buckets != nil - }) - suite.srv.GetRaftCluster().GetOpts().(*config.PersistOptions).SetRegionBucketEnabled(false) - - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegion(context.Background(), []byte("a"), opt.WithBuckets()) - re.NoError(err) - if r == nil { - return false - } - return r.Buckets == nil - }) - suite.srv.GetRaftCluster().GetOpts().(*config.PersistOptions).SetRegionBucketEnabled(true) - - re.NoError(failpoint.Enable("github.com/tikv/pd/server/grpcClientClosed", `return(true)`)) - re.NoError(failpoint.Enable("github.com/tikv/pd/server/useForwardRequest", `return(true)`)) - re.NoError(suite.reportBucket.Send(breq)) - re.Error(suite.reportBucket.RecvMsg(breq)) - re.NoError(failpoint.Disable("github.com/tikv/pd/server/grpcClientClosed")) - re.NoError(failpoint.Disable("github.com/tikv/pd/server/useForwardRequest")) -} - -func (suite *clientTestSuite) TestGetPrevRegion() { - re := suite.Require() - regionLen := 10 - regions := make([]*metapb.Region, 0, regionLen) - for i := range regionLen { - regionID := regionIDAllocator.alloc() - r := &metapb.Region{ - Id: regionID, - RegionEpoch: &metapb.RegionEpoch{ - ConfVer: 1, - Version: 1, - }, - StartKey: []byte{byte(i)}, - EndKey: []byte{byte(i + 1)}, - Peers: peers, - } - regions = append(regions, r) - req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(), - Region: r, - Leader: peers[0], - } - err := suite.regionHeartbeat.Send(req) - re.NoError(err) - } - for i := range 20 { - testutil.Eventually(re, func() bool { - r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) - re.NoError(err) - if i > 0 && i < regionLen { - return reflect.DeepEqual(peers[0], r.Leader) && - reflect.DeepEqual(regions[i-1], r.Meta) - } - return r == nil - }) - } -} - func (suite *clientTestSuite) TestScanRegions() { re := suite.Require() regionLen := 10 @@ -1299,126 +1185,6 @@ func (suite *clientTestSuite) TestScanRegions() { check([]byte{1}, []byte{6}, 2, regions[1:3]) } -func (suite *clientTestSuite) TestGetRegionByID() { - re := suite.Require() - regionID := regionIDAllocator.alloc() - region := &metapb.Region{ - Id: regionID, - RegionEpoch: &metapb.RegionEpoch{ - ConfVer: 1, - Version: 1, - }, - Peers: peers, - } - req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(), - Region: region, - Leader: peers[0], - } - err := suite.regionHeartbeat.Send(req) - re.NoError(err) - - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegionByID(context.Background(), regionID) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(region, r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) - }) - - // test WithCallerComponent - testutil.Eventually(re, func() bool { - r, err := suite.client. - WithCallerComponent(caller.GetComponent(0)). - GetRegionByID(context.Background(), regionID) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(region, r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) - }) -} - -func (suite *clientTestSuite) TestGetRegionConcurrently() { - suite.client.(interface{ EnableRouterClient() }).EnableRouterClient() - - re := suite.Require() - ctx, cancel := context.WithCancel(suite.ctx) - defer cancel() - - regions := make([]*metapb.Region, 0, 2) - for i := range 2 { - regionID := regionIDAllocator.alloc() - region := &metapb.Region{ - Id: regionID, - RegionEpoch: &metapb.RegionEpoch{ - ConfVer: 1, - Version: 1, - }, - StartKey: []byte{byte(i)}, - EndKey: []byte{byte(i + 1)}, - Peers: peers, - } - re.NoError(suite.regionHeartbeat.Send(&pdpb.RegionHeartbeatRequest{ - Header: newHeader(), - Region: region, - Leader: peers[0], - })) - regions = append(regions, region) - } - - const concurrency = 1000 - - wg := sync.WaitGroup{} - wg.Add(concurrency) - for range concurrency { - go func() { - defer wg.Done() - switch rand.Intn(3) { - case 0: - region := regions[0] - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegion(ctx, region.GetStartKey()) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(region, r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) && - r.Buckets == nil - }) - case 1: - testutil.Eventually(re, func() bool { - r, err := suite.client.GetPrevRegion(ctx, regions[1].GetStartKey()) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(regions[0], r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) && - r.Buckets == nil - }) - case 2: - region := regions[0] - testutil.Eventually(re, func() bool { - r, err := suite.client.GetRegionByID(ctx, region.GetId()) - re.NoError(err) - if r == nil { - return false - } - return reflect.DeepEqual(region, r.Meta) && - reflect.DeepEqual(peers[0], r.Leader) && - r.Buckets == nil - }) - } - }() - } - wg.Wait() -} - func (suite *clientTestSuite) TestGetStore() { re := suite.Require() cluster := suite.srv.GetRaftCluster() diff --git a/tests/integrations/client/router_client_test.go b/tests/integrations/client/router_client_test.go new file mode 100644 index 00000000000..06100f3d4f2 --- /dev/null +++ b/tests/integrations/client/router_client_test.go @@ -0,0 +1,368 @@ +// Copyright 2025 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client_test + +import ( + "context" + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + + pd "github.com/tikv/pd/client" + "github.com/tikv/pd/client/opt" + "github.com/tikv/pd/client/pkg/caller" + "github.com/tikv/pd/pkg/utils/assertutil" + "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/config" +) + +func TestRouterClientEnabledSuite(t *testing.T) { + suite.Run(t, &routerClientSuite{routerClientEnabled: true}) +} + +func TestRouterClientDisabledSuite(t *testing.T) { + suite.Run(t, &routerClientSuite{routerClientEnabled: false}) +} + +type routerClientSuite struct { + suite.Suite + cleanup testutil.CleanupFunc + ctx context.Context + clean context.CancelFunc + srv *server.Server + client pd.Client + grpcPDClient pdpb.PDClient + regionHeartbeat pdpb.PD_RegionHeartbeatClient + reportBucket pdpb.PD_ReportBucketsClient + + routerClientEnabled bool +} + +func (suite *routerClientSuite) SetupSuite() { + var err error + re := suite.Require() + suite.srv, suite.cleanup, err = server.NewTestServer(re, assertutil.CheckerWithNilAssert(re)) + re.NoError(err) + suite.grpcPDClient = testutil.MustNewGrpcClient(re, suite.srv.GetAddr()) + + server.MustWaitLeader(re, []*server.Server{suite.srv}) + bootstrapServer(re, newHeader(), suite.grpcPDClient) + + suite.ctx, suite.clean = context.WithCancel(context.Background()) + suite.client = setupCli(suite.ctx, re, suite.srv.GetEndpoints(), opt.WithEnableRouterClient(suite.routerClientEnabled)) + + suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) + re.NoError(err) + suite.reportBucket, err = suite.grpcPDClient.ReportBuckets(suite.ctx) + re.NoError(err) + cluster := suite.srv.GetRaftCluster() + re.NotNil(cluster) + cluster.GetOpts().(*config.PersistOptions).SetRegionBucketEnabled(true) +} + +// TearDownSuite cleans up the test cluster and client. +func (suite *routerClientSuite) TearDownSuite() { + suite.client.Close() + suite.clean() + suite.cleanup() +} + +func (suite *routerClientSuite) TestGetRegion() { + re := suite.Require() + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + Peers: peers, + } + req := &pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + } + err := suite.regionHeartbeat.Send(req) + re.NoError(err) + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a")) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + breq := &pdpb.ReportBucketsRequest{ + Header: newHeader(), + Buckets: &metapb.Buckets{ + RegionId: regionID, + Version: 1, + Keys: [][]byte{[]byte("a"), []byte("z")}, + PeriodInMs: 2000, + Stats: &metapb.BucketStats{ + ReadBytes: []uint64{1}, + ReadKeys: []uint64{1}, + ReadQps: []uint64{1}, + WriteBytes: []uint64{1}, + WriteKeys: []uint64{1}, + WriteQps: []uint64{1}, + }, + }, + } + re.NoError(suite.reportBucket.Send(breq)) + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), opt.WithBuckets()) + re.NoError(err) + if r == nil { + return false + } + return r.Buckets != nil + }) + suite.srv.GetRaftCluster().GetOpts().(*config.PersistOptions).SetRegionBucketEnabled(false) + + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), opt.WithBuckets()) + re.NoError(err) + if r == nil { + return false + } + return r.Buckets == nil + }) + suite.srv.GetRaftCluster().GetOpts().(*config.PersistOptions).SetRegionBucketEnabled(true) + + re.NoError(failpoint.Enable("github.com/tikv/pd/server/grpcClientClosed", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/useForwardRequest", `return(true)`)) + re.NoError(suite.reportBucket.Send(breq)) + re.Error(suite.reportBucket.RecvMsg(breq)) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/grpcClientClosed")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/useForwardRequest")) +} + +func (suite *routerClientSuite) TestGetPrevRegion() { + re := suite.Require() + regionLen := 10 + regions := make([]*metapb.Region, 0, regionLen) + for i := range regionLen { + regionID := regionIDAllocator.alloc() + r := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + Peers: peers, + } + regions = append(regions, r) + req := &pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: r, + Leader: peers[0], + } + err := suite.regionHeartbeat.Send(req) + re.NoError(err) + } + for i := range 20 { + testutil.Eventually(re, func() bool { + r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) + re.NoError(err) + if i > 0 && i < regionLen { + return reflect.DeepEqual(peers[0], r.Leader) && + reflect.DeepEqual(regions[i-1], r.Meta) + } + return r == nil + }) + } +} + +func (suite *routerClientSuite) TestGetRegionByID() { + re := suite.Require() + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + Peers: peers, + } + req := &pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + } + err := suite.regionHeartbeat.Send(req) + re.NoError(err) + + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegionByID(context.Background(), regionID) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) + }) + + // test WithCallerComponent + testutil.Eventually(re, func() bool { + r, err := suite.client. + WithCallerComponent(caller.GetComponent(0)). + GetRegionByID(context.Background(), regionID) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) + }) +} + +func (suite *routerClientSuite) TestGetRegionConcurrently() { + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + wg := sync.WaitGroup{} + suite.dispatchConcurrentRequests(ctx, re, &wg) + wg.Wait() +} + +func (suite *routerClientSuite) dispatchConcurrentRequests(ctx context.Context, re *require.Assertions, wg *sync.WaitGroup) { + regions := make([]*metapb.Region, 0, 2) + for i := range 2 { + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + Peers: peers, + } + re.NoError(suite.regionHeartbeat.Send(&pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + })) + regions = append(regions, region) + } + + const concurrency = 1000 + + wg.Add(concurrency) + for range concurrency { + go func() { + defer wg.Done() + // Randomly sleep to avoid the concurrent requests to be dispatched at the same time. + seed := rand.Intn(100) + time.Sleep(time.Duration(seed) * time.Millisecond) + switch seed % 3 { + case 0: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(ctx, region.GetStartKey()) + if err != nil { + re.ErrorContains(err, context.Canceled.Error()) + } + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 1: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetPrevRegion(ctx, regions[1].GetStartKey()) + if err != nil { + re.ErrorContains(err, context.Canceled.Error()) + } + if r == nil { + return false + } + return reflect.DeepEqual(regions[0], r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 2: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegionByID(ctx, region.GetId()) + if err != nil { + re.ErrorContains(err, context.Canceled.Error()) + } + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + } + }() + } +} + +func (suite *routerClientSuite) TestDynamicallyEnableRouterClient() { + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + wg := sync.WaitGroup{} + for _, enabled := range []bool{!suite.routerClientEnabled, suite.routerClientEnabled} { + suite.dispatchConcurrentRequests(ctx, re, &wg) + wg.Wait() + err := suite.client.UpdateOption(opt.EnableRouterClient, enabled) + re.NoError(err) + } +} + +func (suite *routerClientSuite) TestConcurrentlyEnableRouterClient() { + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + wg := sync.WaitGroup{} + // Concurrently enable and disable the router client. + for _, enabled := range []bool{!suite.routerClientEnabled, suite.routerClientEnabled} { + suite.dispatchConcurrentRequests(ctx, re, &wg) + // Switch the router client option immediately right after the concurrent requests dispatch. + err := suite.client.UpdateOption(opt.EnableRouterClient, enabled) + re.NoError(err) + select { + case <-time.After(time.Second): + // Let the bullet fly for a while. + case <-ctx.Done(): + } + } + wg.Wait() +}