diff --git a/internal/pool/defaults.go b/internal/pool/defaults.go
index 2591e8438..a88c9f7b3 100644
--- a/internal/pool/defaults.go
+++ b/internal/pool/defaults.go
@@ -1,6 +1,8 @@
 package pool
 
-const DefaultLimit = 50
+const (
+	DefaultLimit = 50
+)
 
 var defaultTrace = &Trace{
 	OnNew: func(info *NewStartInfo) func(info *NewDoneInfo) {
diff --git a/internal/pool/errors.go b/internal/pool/errors.go
index 36b6526c7..24a37f31e 100644
--- a/internal/pool/errors.go
+++ b/internal/pool/errors.go
@@ -4,7 +4,4 @@ import (
 	"errors"
 )
 
-var (
-	errClosedPool     = errors.New("closed pool")
-	errItemIsNotAlive = errors.New("item is not alive")
-)
+var errClosedPool = errors.New("closed pool")
diff --git a/internal/pool/pool.go b/internal/pool/pool.go
index 71995b054..d37fd7d75 100644
--- a/internal/pool/pool.go
+++ b/internal/pool/pool.go
@@ -2,6 +2,7 @@ package pool
 
 import (
 	"context"
+	"sync"
 	"time"
 
 	"golang.org/x/sync/errgroup"
@@ -38,12 +39,32 @@ type (
 		createTimeout time.Duration
 		closeTimeout  time.Duration
 
-		mu    xsync.Mutex
-		idle  []PT
-		index map[PT]struct{}
-		done  chan struct{}
+		// queue is a buffered channel that holds ready-to-use items.
+		// Newly created items are sent to this channel by the spawner goroutine.
+		// getItem reads from this channel to obtain items for usage.
+		// putItem sends the item back to this channel when it is no longer needed.
+		// The capacity of the buffered channel must be equal to the configured pool size
+		// (it MUST NOT be smaller).
+		// If an item is in this queue, it is considered idle (not in use).
+		queue chan PT
+
+		// itemTokens, similarly to 'queue', is a buffered channel, and it holds 'tokens'.
+		// The presence of a token in this channel indicates a request to create an item.
+		// Every token eventually results in the creation of a new item (spawnItems makes sure of that).
+		//
+		// itemTokens must have the same capacity as queue.
+		// At any point in time the number of existing tokens plus the number of existing items
+		// MUST be equal to the pool size. getItem/putItem MUST add a new token whenever they
+		// discover that an item in use is no longer alive and discard it.
+		itemTokens chan struct{}
+
+		done chan struct{}
 
 		stats *safeStats
+
+		spawnCancel context.CancelFunc
+
+		wg *sync.WaitGroup
 	}
 
 	option[PT Item[T], T any] func(p *Pool[PT, T])
 )
@@ -159,6 +180,15 @@ func New[PT Item[T], T any](
 		}
 	}
 
+	p.queue = make(chan PT, p.limit)
+	p.itemTokens = make(chan struct{}, p.limit)
+	go func() {
+		// fill tokens
+		for i := 0; i < p.limit; i++ {
+			p.itemTokens <- struct{}{}
+		}
+	}()
+
 	onDone := p.trace.OnNew(&NewStartInfo{
 		Context: &ctx,
 		Call:    stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/pool.New"),
@@ -172,16 +202,73 @@ func New[PT Item[T], T any](
 
 	p.createItem = createItemWithTimeoutHandling(p.createItem, p)
 
-	p.idle = make([]PT, 0, p.limit)
-	p.index = make(map[PT]struct{}, p.limit)
 	p.stats = &safeStats{
 		v:        stats.Stats{Limit: p.limit},
 		onChange: p.trace.OnChange,
 	}
 
+	var spawnCtx context.Context
+	p.wg = &sync.WaitGroup{}
+	spawnCtx, p.spawnCancel = xcontext.WithCancel(xcontext.ValueOnly(ctx))
+	p.wg.Add(1)
+	go p.spawnItems(spawnCtx)
+
 	return p
 }
 
+// spawnItems creates one item for each available itemToken and sends the new item to the internal item queue.
+// It ensures that the pool always holds a number of items equal to the configured limit.
+// If item creation fails, it is retried until it succeeds.
+func (p *Pool[PT, T]) spawnItems(ctx context.Context) {
+	defer p.wg.Done()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-p.done:
+			return
+		case <-p.itemTokens:
+			// got a token, must create an item
+		createLoop:
+			for {
+				select {
+				case <-ctx.Done():
+					return
+				case <-p.done:
+					return
+				default:
+					p.wg.Add(1)
+					err := p.trySpawn(ctx)
+					if err == nil {
+						break createLoop
+					}
+				}
+				// the spawn was unsuccessful, so we need to try again.
+				// A token must always result in a new item and must not be lost.
+			}
+		}
+	}
+}
+
+func (p *Pool[PT, T]) trySpawn(ctx context.Context) error {
+	defer p.wg.Done()
+	item, err := p.createItem(ctx)
+	if err != nil {
+		return err
+	}
+	// the item was created successfully, put it into the queue
+	select {
+	case <-ctx.Done():
+		return nil
+	case <-p.done:
+		return nil
+	case p.queue <- item:
+		p.stats.Idle().Inc()
+	}
+
+	return nil
+}
+
 // defaultCreateItem returns a new item
 func defaultCreateItem[T any, PT Item[T]](ctx context.Context) (PT, error) {
 	var item T
@@ -247,31 +334,12 @@ func createItemWithContext[PT Item[T], T any](
 		return xerrors.WithStackTrace(err)
 	}
 
-	needCloseItem := true
-	defer func() {
-		if needCloseItem {
-			_ = p.closeItem(ctx, newItem)
-		}
-	}()
-
 	select {
 	case <-p.done:
 		return xerrors.WithStackTrace(errClosedPool)
 	case <-ctx.Done():
-		p.mu.Lock()
-		defer p.mu.Unlock()
-
-		if len(p.index) < p.limit {
-			p.idle = append(p.idle, newItem)
-			p.index[newItem] = struct{}{}
-			p.stats.Index().Inc()
-			needCloseItem = false
-		}
-
 		return xerrors.WithStackTrace(ctx.Err())
 	case ch <- newItem:
-		needCloseItem = false
-
 		return nil
 	}
 }
@@ -280,6 +348,10 @@ func (p *Pool[PT, T]) Stats() stats.Stats {
 	return p.stats.Get()
 }
 
+// getItem retrieves an item from the queue.
+// If the retrieved item turns out to be not alive, it is destroyed,
+// one token is added to the tokens queue, and the spawner goroutine then creates a new item.
+// The process is repeated until an alive item is retrieved.
 func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
 	onDone := p.trace.OnGet(&GetStartInfo{
 		Context: &ctx,
@@ -295,48 +367,30 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
 		return nil, xerrors.WithStackTrace(err)
 	}
 
-	select {
-	case <-p.done:
-		return nil, xerrors.WithStackTrace(errClosedPool)
-	case <-ctx.Done():
-		return nil, xerrors.WithStackTrace(ctx.Err())
-	default:
-		var item PT
-		p.mu.WithLock(func() {
-			if len(p.idle) > 0 {
-				item, p.idle = p.idle[0], p.idle[1:]
-				p.stats.Idle().Dec()
-			}
-		})
-
-		if item != nil {
-			if item.IsAlive() {
-				return item, nil
-			}
-			_ = p.closeItem(ctx, item)
-			p.mu.WithLock(func() {
-				delete(p.index, item)
-			})
-			p.stats.Index().Dec()
-		}
-
-		item, err := p.createItem(ctx)
-		if err != nil {
-			return nil, xerrors.WithStackTrace(err)
-		}
+	// Get an item and make sure it is alive.
+	// The infinite loop guarantees that we either return an alive item
+	// or block until one becomes available.
+	// The calling code is expected to use the context if it wishes to time out the call.
+	for {
+		select {
+		case <-p.done:
+			return nil, xerrors.WithStackTrace(errClosedPool)
+		case <-ctx.Done():
+			return nil, xerrors.WithStackTrace(ctx.Err())
+		case item := <-p.queue: // get or wait for an item
+			p.stats.Idle().Dec()
+			if item != nil {
+				if item.IsAlive() {
+					// the item is alive, return it
-		addedToIndex := false
-		p.mu.WithLock(func() {
-			if len(p.index) < p.limit {
-				p.index[item] = struct{}{}
-				addedToIndex = true
+					return item, nil
+				}
+				// the item is not alive
+				_ = p.closeItem(ctx, item) // clean up the dead item
 			}
-		})
-		if addedToIndex {
-			p.stats.Index().Inc()
+			p.itemTokens <- struct{}{} // signal the spawner goroutine to create a new item
+			// and try again
 		}
-
-		return item, nil
 	}
 }
 
@@ -358,25 +412,28 @@ func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) {
 	select {
 	case <-p.done:
 		return xerrors.WithStackTrace(errClosedPool)
+	case <-ctx.Done():
+		return xerrors.WithStackTrace(ctx.Err())
 	default:
-		if !item.IsAlive() {
+		if item.IsAlive() {
+			// put the item back into the queue
+			select {
+			case <-p.done:
+				return xerrors.WithStackTrace(errClosedPool)
+			case <-ctx.Done():
+				return xerrors.WithStackTrace(ctx.Err())
+			case p.queue <- item:
+				p.stats.Idle().Inc()
+			}
+		} else {
+			// the item is not alive:
+			// add a token and close the item
+			p.itemTokens <- struct{}{}
 			_ = p.closeItem(ctx, item)
-
-			p.mu.WithLock(func() {
-				delete(p.index, item)
-			})
-			p.stats.Index().Dec()
-
-			return xerrors.WithStackTrace(errItemIsNotAlive)
 		}
-
-		p.mu.WithLock(func() {
-			p.idle = append(p.idle, item)
-		})
-		p.stats.Idle().Inc()
-
-		return nil
 	}
+
+	return nil
 }
 
 func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error {
@@ -412,14 +469,13 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item
 		return xerrors.WithStackTrace(err)
 	}
 
+	p.stats.InUse().Inc()
 	defer func() {
 		_ = p.putItem(ctx, item)
+		p.stats.InUse().Dec()
 	}()
 
-	p.stats.InUse().Inc()
-	defer p.stats.InUse().Dec()
-
 	err = f(ctx, item)
 	if err != nil {
 		return xerrors.WithStackTrace(err)
 	}
@@ -479,17 +535,27 @@ func (p *Pool[PT, T]) Close(ctx context.Context) (finalErr error) {
 		})
 	}()
 
+	// cancel the spawner (and any in-flight createItem calls)
+	p.spawnCancel()
+
+	// Only the done channel is closed here.
+	// Because the queue has multiple senders, it is not closed;
+	// instead it is drained fully so that every existing item gets closed.
 	close(p.done)
 
-	p.mu.Lock()
-	defer p.mu.Unlock()
+	p.wg.Wait()
 
 	var g errgroup.Group
-	for item := range p.index {
-		item := item
-		g.Go(func() error {
-			return item.Close(ctx)
-		})
+shutdownLoop:
+	for {
+		select {
+		case item := <-p.queue:
+			g.Go(func() error {
+				return item.Close(ctx)
+			})
+		default:
+			break shutdownLoop
+		}
 	}
 	if err := g.Wait(); err != nil {
 		return xerrors.WithStackTrace(err)
diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go
index 6ba7036c9..7decf3811 100644
--- a/internal/pool/pool_test.go
+++ b/internal/pool/pool_test.go
@@ -45,6 +45,34 @@ func (t testItem) Close(context.Context) error {
 	return nil
 }
 
+type testItemV2 struct {
+	v uint32
+
+	closeCalls int32
+	dead       int32
+}
+
+func (t *testItemV2) IsAlive() bool {
+	return atomic.LoadInt32(&t.dead) == 0
+}
+
+func (t *testItemV2) ID() string {
+	return ""
+}
+
+func (t *testItemV2) Close(context.Context) error {
+	atomic.AddInt32(&t.closeCalls, 1)
+
+	return nil
+}
+
+func (t *testItemV2) failAfter(d time.Duration) {
+	timer := time.NewTimer(d)
+	<-timer.C
+	atomic.CompareAndSwapInt32(&t.dead, 0, 1)
+	timer.Stop()
+}
+
 func TestPool(t *testing.T) {
 	rootCtx := xtest.Context(t)
 	t.Run("New", func(t *testing.T) {
@@ -222,6 +250,7 @@ func TestPool(t *testing.T) {
 			require.EqualValues(t, atomic.LoadInt64(&createCounter), atomic.LoadInt64(&closeCounter))
 		}, xtest.StopAfter(time.Second))
 	})
+
 	t.Run("IsAlive", func(t *testing.T) {
 		xtest.TestManyTimes(t, func(t testing.TB) {
 			var (
@@ -289,6 +318,147 @@ func TestPool(t *testing.T) {
 			wg.Wait()
 		}, xtest.StopAfter(42*time.Second))
 	})
+
+	t.Run("Chaos", func(t *testing.T) {
+		// FrozenCalls makes sure that the pool size is fully utilized,
+		// but never exceeds the configured limit. It freezes pool items
+		// and checks that the pool is still operational as long as at least
+		// one item remains available.
+		t.Run("FrozenCalls", func(t *testing.T) {
+			const (
+				poolSize         = 11
+				callsAmount      = 100
+				callFreezeFactor = 10 // every 'callFreezeFactor'-th call will freeze
+			)
+			var (
+				newItems    int64
+				deleteItems int64
+			)
+			p := New(rootCtx,
+				WithLimit[*testItem, testItem](poolSize),
+				WithCreateFunc(func(context.Context) (*testItem, error) {
+					atomic.AddInt64(&newItems, 1)
+					v := &testItem{
+						onClose: func() error {
+							atomic.AddInt64(&deleteItems, 1)
+
+							return nil
+						},
+					}
+
+					return v, nil
+				}),
+			)
+
+			var (
+				freezeCh = make(chan struct{})
+				wg       = &sync.WaitGroup{}
+			)
+			wg.Add(callsAmount - callsAmount/callFreezeFactor + 1)
+			go func() {
+				for i := 1; i <= callsAmount; i++ {
+					go func(ctr int) {
+						err := p.With(rootCtx, func(ctx context.Context, _ *testItem) error {
+							if ctr%callFreezeFactor == 0 {
+								<-freezeCh
+							}
+
+							return nil
+						})
+						if err != nil && !xerrors.Is(err, errClosedPool, context.Canceled) {
+							t.Fail()
+						}
+						wg.Done()
+					}(i)
+				}
+				wg.Done()
+			}()
+			// every call that is not frozen should complete
+			wg.Wait()
+
+			time.Sleep(time.Second) // ensure the items are put back into the queue
+			require.Equal(t, poolSize-callsAmount/callFreezeFactor, p.stats.Get().Idle)
+			require.Equal(t, callsAmount/callFreezeFactor, p.stats.Get().InUse)
+			require.GreaterOrEqual(t, atomic.LoadInt64(&newItems), int64(poolSize))
+
+			// unfreeze
+			wg.Add(callsAmount / callFreezeFactor)
+			for i := 0; i < callsAmount/callFreezeFactor; i++ {
+				freezeCh <- struct{}{}
+			}
+			wg.Wait()
+			err := p.Close(rootCtx)
+			require.NoError(t, err)
+			require.EqualValues(t, atomic.LoadInt64(&newItems), atomic.LoadInt64(&deleteItems))
+		})
+
+		// FailingItems checks proper item recycling. Created items will fail
+		// at random times, which should not affect the call results.
+		t.Run("FailingItems", func(t *testing.T) {
+			const (
+				poolSize = 11
+				runTime  = 20 // the test will run for 'runTime' seconds
+				lifetime = 2  // each item will die in [1;3) seconds
+				callFreq = 5  // a new call is made every 'callFreq' milliseconds
+				callTake = 4  // each call takes 'callTake' milliseconds
+			)
+			var (
+				callsCtr int64
+				items    = make([]*testItemV2, 0, 1000)
+				mx       = &sync.Mutex{}
+			)
+			p := New(rootCtx,
+				WithLimit[*testItemV2, testItemV2](poolSize),
+				WithCreateFunc(func(context.Context) (*testItemV2, error) {
+					v := &testItemV2{}
+					mx.Lock()
+					items = append(items, v)
+					mx.Unlock()
+					go v.failAfter(time.Duration(rand.Intn(lifetime)+1) * time.Second) //nolint:gosec
+
+					return v, nil
+				}),
+			)
+
+			tickDone := time.After(runTime * time.Second)
+			tickCall := time.Tick(callFreq * time.Millisecond)
+			wg := sync.WaitGroup{}
+			wg.Add(1)
+			go func() {
+				for {
+					select {
+					case <-tickDone:
+						wg.Done()
+
+						return
+					case <-tickCall:
+						wg.Add(1)
+						go func() {
+							err := p.With(rootCtx, func(ctx context.Context, _ *testItemV2) error {
+								atomic.AddInt64(&callsCtr, 1)
+								time.Sleep(callTake * time.Millisecond)
+
+								return nil
+							})
+							if err != nil && !xerrors.Is(err, errClosedPool, context.Canceled) {
+								t.Fail()
+							}
+							wg.Done()
+						}()
+					}
+				}
+			}()
+			wg.Wait()
+			err := p.Close(rootCtx)
+			require.NoError(t, err)
+			t.Log("created items:", len(items), "calls made:", callsCtr)
+			// ensure each item was closed, and only once
+			for _, item := range items {
+				require.Equal(t, int32(1), item.closeCalls)
+			}
+		})
+	})
 }
 
 func TestSafeStatsRace(t *testing.T) {