Skip to content

feat: Implement ReadN #276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ func (s *Source) Open(ctx context.Context, pos opencdc.Position) error {
return nil
}

func (s *Source) Read(ctx context.Context) (opencdc.Record, error) {
return s.iterator.Next(ctx)
// ReadN returns up to n records by delegating to the source's iterator via
// NextN. The records and the error are passed through to the caller unchanged.
func (s *Source) ReadN(ctx context.Context, n int) ([]opencdc.Record, error) {
return s.iterator.NextN(ctx, n)
}

func (s *Source) Ack(ctx context.Context, pos opencdc.Position) error {
Expand Down
6 changes: 3 additions & 3 deletions source/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (

// Iterator is an object that can iterate over a queue of records.
type Iterator interface {
// Next takes and returns the next record from the queue. Next is allowed to
// block until either a record is available or the context gets canceled.
Next(context.Context) (opencdc.Record, error)
// NextN takes and returns up to n records from the queue. NextN is allowed to
// block until either at least one record is available or the context gets canceled.
NextN(context.Context, int) ([]opencdc.Record, error)
// Ack signals that a record at a specific position was successfully
// processed.
Ack(context.Context, opencdc.Position) error
Expand Down
63 changes: 45 additions & 18 deletions source/logrepl/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type CDCConfig struct {
}

// CDCIterator asynchronously listens for events from the logical replication
// slot and returns them to the caller through Next.
// slot and returns them to the caller through NextN.
type CDCIterator struct {
config CDCConfig
records chan opencdc.Record
Expand Down Expand Up @@ -113,35 +113,62 @@ func (i *CDCIterator) StartSubscriber(ctx context.Context) error {
return nil
}

// Next returns the next record retrieved from the subscription. This call will
// block until either a record is returned from the subscription, the
// subscription stops because of an error or the context gets canceled.
// Returns error when the subscription has not been started.
func (i *CDCIterator) Next(ctx context.Context) (opencdc.Record, error) {
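// NextN takes and returns up to n records from the queue. NextN is allowed to
// block until either at least one record is available, the subscription stops
// or the context gets canceled. It returns an error if logical replication has
// not been started or if n is not greater than 0.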
func (i *CDCIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) {
if !i.subscriberReady() {
return opencdc.Record{}, errors.New("logical replication has not been started")
return nil, errors.New("logical replication has not been started")
}

for {
if n <= 0 {
return nil, fmt.Errorf("n must be greater than 0, got %d", n)
}

recs := make([]opencdc.Record, 0, n)

// Block until at least one record is received or context is canceled
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-i.sub.Done():
if err := i.sub.Err(); err != nil {
return nil, fmt.Errorf("logical replication error: %w", err)
}
if err := ctx.Err(); err != nil {
// subscription is done because the context is canceled, we went
// into the wrong case by chance
return nil, err
}
// subscription stopped without an error and the context is still
// open, this is a strange case, shouldn't actually happen
return nil, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)")
case rec := <-i.records:
recs = append(recs, rec)
}

for len(recs) < n {
select {
case rec := <-i.records:
recs = append(recs, rec)
case <-ctx.Done():
return opencdc.Record{}, ctx.Err()
return nil, ctx.Err()
case <-i.sub.Done():
if err := i.sub.Err(); err != nil {
return opencdc.Record{}, fmt.Errorf("logical replication error: %w", err)
return recs, fmt.Errorf("logical replication error: %w", err)
}
if err := ctx.Err(); err != nil {
// subscription is done because the context is canceled, we went
// into the wrong case by chance
return opencdc.Record{}, err
// Return what we have with context error
return recs, err
}
// subscription stopped without an error and the context is still
// open, this is a strange case, shouldn't actually happen
return opencdc.Record{}, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)")
case r := <-i.records:
return r, nil
// Return what we have with subscription stopped error
return recs, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)")
default:
// No more records currently available
return recs, nil
}
}

return recs, nil
}

// Ack forwards the acknowledgment to the subscription.
Expand Down
192 changes: 152 additions & 40 deletions source/logrepl/cdc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestCDCIterator_New(t *testing.T) {
}
}

func TestCDCIterator_Next(t *testing.T) {
func TestCDCIterator_Operation_NextN(t *testing.T) {
ctx := test.Context(t)
is := is.New(t)

Expand Down Expand Up @@ -343,9 +343,11 @@ func TestCDCIterator_Next(t *testing.T) {
// fetch the change
nextCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
got, err := i.Next(nextCtx)
records, err := i.NextN(nextCtx, 1)
is.NoErr(err)

got := records[0]

readAt, err := got.Metadata.GetReadAt()
is.NoErr(err)
is.True(readAt.After(now)) // ReadAt should be after now
Expand All @@ -359,40 +361,6 @@ func TestCDCIterator_Next(t *testing.T) {
}
}

// TestCDCIterator_Next_Fail exercises the error paths of CDCIterator.Next:
// calling it after the iterator was torn down and calling it on an iterator
// whose subscriber was never started.
func TestCDCIterator_Next_Fail(t *testing.T) {
	ctx := test.Context(t)

	pool := test.ConnectPool(ctx, t, test.RepmgrConnString)
	table := test.SetupTestTable(ctx, t, pool)

	t.Run("fail when sub is done", func(t *testing.T) {
		is := is.New(t)

		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		// Teardown runs before Next; the test expects Next to then fail with
		// a logical replication error.
		is.NoErr(i.Teardown(ctx))

		_, err := i.Next(ctx)
		// Assert the error exists before dereferencing it, otherwise a
		// regression would show up as a nil-pointer panic instead of a
		// failed assertion.
		is.True(err != nil)

		expectErr := "logical replication error:"
		match := strings.Contains(err.Error(), expectErr)
		if !match {
			// Log both sides so a mismatch is easy to diagnose.
			t.Logf("%s != %s", err.Error(), expectErr)
		}
		is.True(match)
	})

	t.Run("fail when subscriber is not started", func(t *testing.T) {
		is := is.New(t)

		// The last argument is false here (true in the subtest above); the
		// assertion below expects Next to report that replication has not
		// been started.
		i := testCDCIterator(ctx, t, pool, table, false)

		_, nexterr := i.Next(ctx)
		is.True(nexterr != nil) // guard the Error() call below
		is.Equal(nexterr.Error(), "logical replication has not been started")
	})
}

func TestCDCIterator_EnsureLSN(t *testing.T) {
ctx := test.Context(t)
is := is.New(t)
Expand All @@ -407,8 +375,11 @@ func TestCDCIterator_EnsureLSN(t *testing.T) {
VALUES (6, 'bizz', 456, false, 12.3, 14)`, table))
is.NoErr(err)

r, err := i.Next(ctx)
rr, err := i.NextN(ctx, 1)
is.NoErr(err)
is.True(len(rr) > 0)

r := rr[0]

p, err := position.ParseSDKPosition(r.Position)
is.NoErr(err)
Expand Down Expand Up @@ -485,6 +456,138 @@ func TestCDCIterator_Ack(t *testing.T) {
})
}
}
// TestCDCIterator_NextN exercises CDCIterator.NextN: full and partial
// batches, context expiry, an iterator that was never started, invalid batch
// sizes and a teardown racing with a read.
func TestCDCIterator_NextN(t *testing.T) {
	ctx := test.Context(t)
	pool := test.ConnectPool(ctx, t, test.RepmgrConnString)
	table := test.SetupTestTable(ctx, t, pool)

	t.Run("retrieve exact N records", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		for j := 1; j <= 3; j++ {
			_, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5)
				VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, j+10, j, j*100))
			is.NoErr(err)
		}

		var allRecords []opencdc.Record
		attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		defer cancel()

		// NextN may return fewer records than requested, so keep asking for
		// the remainder until all 3 have been collected.
		for len(allRecords) < 3 {
			records, err := i.NextN(attemptCtx, 3-len(allRecords))
			is.NoErr(err)
			// Only proceed if we got at least one record
			is.True(len(records) > 0)
			allRecords = append(allRecords, records...)
		}

		is.Equal(len(allRecords), 3)

		for j, r := range allRecords {
			is.Equal(r.Operation, opencdc.OperationCreate)
			is.Equal(r.Key.(opencdc.StructuredData)["id"], int64(j+11))
			change := r.Payload
			data := change.After.(opencdc.StructuredData)
			is.Equal(data["column1"], fmt.Sprintf("test-%d", j+1))
			//nolint:gosec // no risk to overflow
			is.Equal(data["column2"], (int32(j)+1)*100)
		}
	})

	t.Run("retrieve fewer records than requested", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		for j := 1; j <= 2; j++ {
			_, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5)
				VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, j+20, j, j*100))
			is.NoErr(err)
		}

		// Bound the collection phase so a missing record fails the test
		// instead of blocking it indefinitely.
		collectCtx, cancelCollect := context.WithTimeout(ctx, 5*time.Second)
		defer cancelCollect()

		// Will keep calling NextN until all records are received
		var records []opencdc.Record
		for len(records) < 2 {
			recordsTmp, err := i.NextN(collectCtx, 5)
			is.NoErr(err)
			// Every successful call must make progress.
			is.True(len(recordsTmp) > 0)
			records = append(records, recordsTmp...)
		}

		// nothing else to fetch
		ctxWithTimeout, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
		defer cancel()
		_, err := i.NextN(ctxWithTimeout, 5)
		is.True(errors.Is(err, context.DeadlineExceeded))

		for j, r := range records {
			is.Equal(r.Operation, opencdc.OperationCreate)
			is.Equal(r.Key.(opencdc.StructuredData)["id"], int64(j+21))
			change := r.Payload
			data := change.After.(opencdc.StructuredData)
			is.Equal(data["column1"], fmt.Sprintf("test-%d", j+1))
			//nolint:gosec // no risk to overflow
			is.Equal(data["column2"], (int32(j)+1)*100)
		}
	})

	t.Run("context cancellation", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		// No rows are inserted in this subtest; the call is expected to end
		// with the context deadline.
		ctxTimeout, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
		defer cancel()

		_, err := i.NextN(ctxTimeout, 5)
		is.True(errors.Is(err, context.DeadlineExceeded))
	})

	t.Run("subscriber not started", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, false)

		_, err := i.NextN(ctx, 5)
		is.True(err != nil) // guard the Error() call below
		is.Equal(err.Error(), "logical replication has not been started")
	})

	t.Run("invalid N values", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		_, err := i.NextN(ctx, 0)
		is.True(err != nil) // guard the Error() call below
		is.True(strings.Contains(err.Error(), "n must be greater than 0"))

		_, err = i.NextN(ctx, -1)
		is.True(err != nil) // guard the Error() call below
		is.True(strings.Contains(err.Error(), "n must be greater than 0"))
	})

	t.Run("subscription termination", func(t *testing.T) {
		is := is.New(t)
		i := testCDCIterator(ctx, t, pool, table, true)
		<-i.sub.Ready()

		_, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5)
				VALUES (30, 'test-1', 100, false, 12.3, 14)`, table))
		is.NoErr(err)

		// Tear the iterator down concurrently. The result is handed back over
		// a channel and asserted on the test goroutine: the assertion helper
		// may call t.FailNow, which must not run on another goroutine.
		teardownErr := make(chan error, 1)
		go func() {
			time.Sleep(100 * time.Millisecond)
			teardownErr <- i.Teardown(ctx)
		}()

		// Depending on timing the read either wins the race and yields
		// records, or observes the teardown and reports a replication error.
		records, err := i.NextN(ctx, 5)
		if err != nil {
			is.True(strings.Contains(err.Error(), "logical replication error"))
		} else {
			is.True(len(records) > 0)
		}

		// Wait for the goroutine so it cannot outlive the subtest.
		is.NoErr(<-teardownErr)
	})
}

func testCDCIterator(ctx context.Context, t *testing.T, pool *pgxpool.Pool, table string, start bool) *CDCIterator {
is := is.New(t)
Expand Down Expand Up @@ -560,8 +663,11 @@ func TestCDCIterator_Schema(t *testing.T) {
)
is.NoErr(err)

r, err := i.Next(ctx)
rr, err := i.NextN(ctx, 1)
is.NoErr(err)
is.True(len(rr) > 0)

r := rr[0]

assertPayloadSchemaOK(ctx, is, test.TestTableAvroSchemaV1, table, r)
assertKeySchemaOK(ctx, is, table, r)
Expand All @@ -580,8 +686,11 @@ func TestCDCIterator_Schema(t *testing.T) {
)
is.NoErr(err)

r, err := i.Next(ctx)
rr, err := i.NextN(ctx, 1)
is.NoErr(err)
is.True(len(rr) > 0)

r := rr[0]

assertPayloadSchemaOK(ctx, is, test.TestTableAvroSchemaV2, table, r)
assertKeySchemaOK(ctx, is, table, r)
Expand All @@ -600,8 +709,11 @@ func TestCDCIterator_Schema(t *testing.T) {
)
is.NoErr(err)

r, err := i.Next(ctx)
rr, err := i.NextN(ctx, 1)
is.NoErr(err)
is.True(len(rr) > 0)

r := rr[0]

assertPayloadSchemaOK(ctx, is, test.TestTableAvroSchemaV3, table, r)
assertKeySchemaOK(ctx, is, table, r)
Expand Down
Loading