diff --git a/destination.go b/destination.go index 0e6e4a3..02adbe9 100644 --- a/destination.go +++ b/destination.go @@ -39,7 +39,7 @@ type Destination struct { getTableName destination.TableFn conn *pgx.Conn - dbInfo *internal.DbInfo + dbInfo *internal.NumericScaleInfo stmtBuilder sq.StatementBuilderType } @@ -347,7 +347,7 @@ func (d *Destination) formatBigRat(ctx context.Context, table string, column str // we need to get the scale of the column so that we can properly // round the result of dividing the input big.Rat's numerator and denominator. - scale, err := d.dbInfo.GetNumericColumnScale(ctx, table, column) + scale, err := d.dbInfo.Get(ctx, table, column) if err != nil { return "", fmt.Errorf("failed getting scale of numeric column: %w", err) } diff --git a/destination_integration_test.go b/destination_integration_test.go index f97e7fe..bdc3bb4 100644 --- a/destination_integration_test.go +++ b/destination_integration_test.go @@ -36,7 +36,7 @@ func TestDestination_Write(t *testing.T) { // tables with capital letters should be quoted tableName := strings.ToUpper(test.RandomIdentifier(t)) - test.SetupTestTableWithName(ctx, t, conn, tableName) + test.SetupTableWithName(ctx, t, conn, tableName) d := NewDestination() err := sdk.Util.ParseConfig( @@ -70,10 +70,11 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 5000}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("123"), "column1": "foo", "column2": 123, "column3": true, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 222, }, }, @@ -88,10 +89,11 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 5}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("234"), "column1": "foo", "column2": 456, "column3": false, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 333, }, }, @@ -106,10 +108,11 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 6}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("345"), "column1": "bar", "column2": 567, "column3": true, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 444, }, }, @@ -124,10 +127,11 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 1}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("456"), "column1": "foobar", "column2": 567, "column3": true, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 555, }, }, @@ -151,6 +155,7 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 123}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("567"), "column1": "abcdef", "column2": 567, "column3": true, @@ -179,9 +184,7 @@ func TestDestination_Write(t *testing.T) { cmp.Diff( tt.record.Payload.After, got, - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, ), ) // -want, +got case opencdc.OperationDelete: @@ -197,7 +200,7 @@ func TestDestination_Batch(t *testing.T) { conn := test.ConnectSimple(ctx, t, test.RegularConnString) tableName := strings.ToUpper(test.RandomIdentifier(t)) - test.SetupTestTableWithName(ctx, t, conn, tableName) + test.SetupTableWithName(ctx, t, conn, tableName) d := NewDestination() @@ -223,10 +226,11 @@ func TestDestination_Batch(t *testing.T) { Key: opencdc.StructuredData{"id": 5}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("123"), "column1": "foo1", "column2": 1,
"column3": false, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 111, }, }, @@ -237,10 +241,11 @@ func TestDestination_Batch(t *testing.T) { Key: opencdc.StructuredData{"id": 6}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("234"), "column1": "foo2", "column2": 2, "column3": true, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 222, }, }, @@ -251,10 +256,11 @@ func TestDestination_Batch(t *testing.T) { Key: opencdc.StructuredData{"id": 7}, Payload: opencdc.Change{ After: opencdc.StructuredData{ + "key": []uint8("345"), "column1": "foo3", "column2": 3, "column3": false, - "column4": nil, + "column4": big.NewRat(123, 10), "UppercaseColumn1": 333, }, }, @@ -275,11 +281,12 @@ func TestDestination_Batch(t *testing.T) { func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id any) (opencdc.StructuredData, error) { row := conn.QueryRow( ctx, - fmt.Sprintf(`SELECT column1, column2, column3, column4, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName), + fmt.Sprintf(`SELECT key, column1, column2, column3, column4, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName), id, ) var ( + key []uint8 col1 string col2 int col3 bool @@ -287,7 +294,7 @@ func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id uppercaseCol1 int ) - err := row.Scan(&col1, &col2, &col3, &col4Str, &uppercaseCol1) + err := row.Scan(&key, &col1, &col2, &col3, &col4Str, &uppercaseCol1) if err != nil { return nil, err } @@ -301,6 +308,7 @@ func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id } return opencdc.StructuredData{ + "key": key, "column1": col1, "column2": col2, "column3": col3, diff --git a/go.mod b/go.mod index a5a63eb..21ad3bf 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/conduitio/evolviconf/evolviyaml v0.1.0 // indirect github.com/conduitio/yaml/v3 v3.3.0 // indirect github.com/curioswitch/go-reassign v0.3.0 // indirect - github.com/daixiang0/gci v0.13.6 // indirect + github.com/daixiang0/gci v0.13.5 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect github.com/ettle/strcase v0.2.0 // indirect diff --git a/go.sum b/go.sum index 0b373ba..cc396f2 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/curioswitch/go-reassign v0.3.0 h1:dh3kpQHuADL3cobV/sSGETA8DOv457dwl+fbBAhrQPs= github.com/curioswitch/go-reassign v0.3.0/go.mod h1:nApPCCTtqLJN/s8HfItCcKV0jIPwluBOvZP+dsJGA88= -github.com/daixiang0/gci v0.13.6 h1:RKuEOSkGpSadkGbvZ6hJ4ddItT3cVZ9Vn9Rybk6xjl8= -github.com/daixiang0/gci v0.13.6/go.mod h1:12etP2OniiIdP4q+kjUGrC/rUagga7ODbqsom5Eo5Yk= +github.com/daixiang0/gci v0.13.5 h1:kThgmH1yBmZSBCh1EJVxQ7JsHpm5Oms0AMed/0LaH4c= +github.com/daixiang0/gci v0.13.5/go.mod h1:12etP2OniiIdP4q+kjUGrC/rUagga7ODbqsom5Eo5Yk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= diff --git a/internal/db_info.go b/internal/numeric_scale_info.go similarity index 75% rename from internal/db_info.go rename to internal/numeric_scale_info.go index 30394bd..d7f12fd 
100644 --- a/internal/db_info.go +++ b/internal/numeric_scale_info.go @@ -22,28 +22,28 @@ import ( "github.com/jackc/pgx/v5" ) -// DbInfo provides information about tables in a database. -type DbInfo struct { +// NumericScaleInfo provides information about the scale of numeric columns +// in a database. +type NumericScaleInfo struct { conn *pgx.Conn cache map[string]*tableCache } // tableCache stores information about a table. -// The information is cached and refreshed every 'cacheExpiration'. type tableCache struct { columns map[string]int } -func NewDbInfo(conn *pgx.Conn) *DbInfo { - return &DbInfo{ +func NewDbInfo(conn *pgx.Conn) *NumericScaleInfo { + return &NumericScaleInfo{ conn: conn, cache: map[string]*tableCache{}, } } -func (d *DbInfo) GetNumericColumnScale(ctx context.Context, table string, column string) (int, error) { - // Check if table exists in cache and is not expired - tableInfo, ok := d.cache[table] +func (i *NumericScaleInfo) Get(ctx context.Context, table string, column string) (int, error) { + // Check if table exists in cache + tableInfo, ok := i.cache[table] if ok { scale, ok := tableInfo.columns[column] if ok { @@ -51,23 +51,23 @@ func (i *NumericScaleInfo) Get(ctx context.Context, table string, column } } else { - // Table info has expired, refresh the cache + // Table not cached yet, initialize its cache entry - d.cache[table] = &tableCache{ + i.cache[table] = &tableCache{ columns: map[string]int{}, } } // Fetch scale from database - scale, err := d.numericScaleFromDb(ctx, table, column) + scale, err := i.fetchFromDB(ctx, table, column) if err != nil { return 0, err } - d.cache[table].columns[column] = scale + i.cache[table].columns[column] = scale return scale, nil } -func (d *DbInfo) numericScaleFromDb(ctx context.Context, table string, column string) (int, error) { +func (i *NumericScaleInfo) fetchFromDB(ctx context.Context, table string, column string) (int, error) { // Query to get the column type and numeric scale query := ` SELECT @@ -83,7 +83,7 @@ func (i *NumericScaleInfo) fetchFromDB(ctx context.Context, table string, column st var dataType string var numericScale *int - err := d.conn.QueryRow(ctx, query, table, column).Scan(&dataType, &numericScale) + err := i.conn.QueryRow(ctx, query, table, column).Scan(&dataType, &numericScale) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return 0, fmt.Errorf("column %s not found in table %s", column, table) diff --git a/internal/table_info.go b/internal/table_info.go new file mode 100644 index 0000000..d1a191c --- /dev/null +++ b/internal/table_info.go @@ -0,0 +1,114 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
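+ +// table_info.go caches per-table column metadata, currently each column's NOT NULL flag. +// The CDC handler refreshes this cache whenever it receives a RelationMessage and consults +// it to decide which Avro fields must be wrapped in a nullable union.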
+ +package internal + +import ( + "context" + "fmt" + + sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/jackc/pgx/v5/pgxpool" +) + +type TableInfo struct { + Name string + Columns map[string]*ColumnInfo +} + +func NewTableInfo(tableName string) *TableInfo { + return &TableInfo{ + Name: tableName, + Columns: make(map[string]*ColumnInfo), + } +} + +type ColumnInfo struct { + IsNotNull bool +} + +type TableInfoFetcher struct { + connPool *pgxpool.Pool + tableInfo map[string]*TableInfo +} + +func NewTableInfoFetcher(connPool *pgxpool.Pool) *TableInfoFetcher { + return &TableInfoFetcher{ + connPool: connPool, + tableInfo: make(map[string]*TableInfo), + } +} + +func (i TableInfoFetcher) Refresh(ctx context.Context, tableName string) error { + tx, err := i.connPool.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start tx for getting table info: %w", err) + } + defer func() { + if err := tx.Rollback(ctx); err != nil { + sdk.Logger(ctx).Warn(). + Err(err). + Msgf("error on tx rollback for getting table info") + } + }() + + query := ` + SELECT a.attname as column_name, a.attnotnull as is_not_null + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = $1::regclass + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum; + ` + + rows, err := tx.Query(ctx, query, WrapSQLIdent(tableName)) + if err != nil { + sdk.Logger(ctx). + Err(err). + Str("query", query). + Msgf("failed to execute table info query") + + return fmt.Errorf("failed to get table info: %w", err) + } + defer rows.Close() + + ti := NewTableInfo(tableName) + for rows.Next() { + var columnName string + var isNotNull bool + + err := rows.Scan(&columnName, &isNotNull) + if err != nil { + return fmt.Errorf("failed to scan table info row: %w", err) + } + + ci := ti.Columns[columnName] + if ci == nil { + ci = &ColumnInfo{} + ti.Columns[columnName] = ci + } + ci.IsNotNull = isNotNull + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to get table info rows: %w", err) + } + + i.tableInfo[tableName] = ti + return nil +} + +func (i TableInfoFetcher) GetTable(name string) *TableInfo { + return i.tableInfo[name] +} diff --git a/source/logrepl/cdc.go b/source/logrepl/cdc.go index 2f4faab..c3306c5 100644 --- a/source/logrepl/cdc.go +++ b/source/logrepl/cdc.go @@ -21,6 +21,7 @@ import ( "time" "github.com/conduitio/conduit-commons/opencdc" + internal2 "github.com/conduitio/conduit-connector-postgres/internal" "github.com/conduitio/conduit-connector-postgres/source/logrepl/internal" "github.com/conduitio/conduit-connector-postgres/source/position" sdk "github.com/conduitio/conduit-connector-sdk" @@ -87,6 +88,7 @@ func NewCDCIterator(ctx context.Context, pool *pgxpool.Pool, c CDCConfig) (*CDCI handler := NewCDCHandler( ctx, internal.NewRelationSet(), + internal2.NewTableInfoFetcher(pool), c.TableKeys, batchesCh, c.WithAvroSchema, @@ -141,9 +143,8 @@ func (i *CDCIterator) StartSubscriber(ctx context.Context) error { return nil } -// NextN returns up to n records from the internal channel with records. -// NextN is allowed to block until either at least one record is available -// or the context gets canceled. +// NextN takes and returns up to n records from the queue. NextN is allowed to +// block until either at least one record is available or the context gets canceled. 
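+// +// A minimal caller-side sketch (names here are illustrative, not part of this change): +// +// recs, err := iter.NextN(ctx, 100) // blocks until at least one record arrives or ctx is done +// if err != nil { return err } +// for _, rec := range recs { +// if err := iter.Ack(ctx, rec.Position); err != nil { return err } +// }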
func (i *CDCIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) { if !i.subscriberReady() { return nil, errors.New("logical replication has not been started") diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index 487117d..bbb9a33 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -49,7 +49,7 @@ func TestCDCIterator_New(t *testing.T) { name: "publication already exists", setup: func(t *testing.T) CDCConfig { is := is.New(t) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) test.CreatePublication(t, pool, table, []string{table}) t.Cleanup(func() { @@ -79,7 +79,7 @@ func TestCDCIterator_New(t *testing.T) { name: "fails to create subscription", setup: func(t *testing.T) CDCConfig { is := is.New(t) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) t.Cleanup(func() { is.NoErr(Cleanup(ctx, CleanupConfig{ @@ -125,7 +125,7 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { is := is.New(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) i := testCDCIterator(ctx, t, pool, table, true) // wait for subscription to be ready @@ -141,8 +141,8 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { name: "should detect insert", setup: func(t *testing.T) { is := is.New(t) - query := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) - VALUES (6, 'bizz', 456, false, 12.3, 14, '{"foo2": "bar2"}', '{"foo2": "baz2"}')`, table) + query := fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (6, '6', 'bizz', 456, false, 12.3, 61)`, table) _, err := pool.Exec(ctx, query) is.NoErr(err) }, @@ -165,11 +165,8 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": int32(456), "column3": false, "column4": big.NewRat(123, 10), - "column5": big.NewRat(14, 1), - "column6": []byte(`{"foo2": "bar2"}`), - "column7": []byte(`{"foo2": "baz2"}`), - "key": nil, - "UppercaseColumn1": nil, + "key": []uint8("6"), + "UppercaseColumn1": int32(61), }, }, }, @@ -200,9 +197,6 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), - "column5": big.NewRat(4, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), "UppercaseColumn1": int32(1), }, @@ -237,9 +231,6 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), - "column5": big.NewRat(4, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), "UppercaseColumn1": int32(1), }, @@ -249,9 +240,6 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), - "column5": big.NewRat(4, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), "UppercaseColumn1": int32(1), }, @@ -286,9 +274,6 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": nil, "column3": nil, "column4": nil, - "column5": nil, - "column6": nil, - "column7": nil, "key": nil, "UppercaseColumn1": nil, }, @@ -323,10 +308,7 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column1": "baz", "column2": int32(789), "column3": false, - "column4": nil, - "column5": big.NewRat(9, 1), - "column6": []byte(`{"foo": "bar"}`), - 
"column7": []byte(`{"foo": "baz"}`), + "column4": big.NewRat(836, 25), "UppercaseColumn1": int32(3), }, }, @@ -360,9 +342,7 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { tt.want, got, cmpopts.IgnoreUnexported(opencdc.Record{}), - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, )) is.NoErr(i.Ack(ctx, got.Position)) }) @@ -374,13 +354,13 @@ func TestCDCIterator_EnsureLSN(t *testing.T) { is := is.New(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) i := testCDCIterator(ctx, t, pool, table, true) <-i.sub.Ready() - _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (6, 'bizz', 456, false, 12.3, 14)`, table)) + _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (6, '6', 'bizz', 456, false, 12.3, 6)`, table)) is.NoErr(err) rr, err := i.NextN(ctx, 1) @@ -468,7 +448,7 @@ func TestCDCIterator_Ack(t *testing.T) { func TestCDCIterator_NextN_InternalBatching(t *testing.T) { ctx := test.Context(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupEmptyTestTable(ctx, t, pool) + table := test.SetupEmptyTable(ctx, t, pool) is := is.New(t) underTest := testCDCIterator(ctx, t, pool, table, true) @@ -503,8 +483,8 @@ func insertTestRows(ctx context.Context, is *is.I, pool *pgxpool.Pool, table str _, err := pool.Exec( ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, i+10, i, i*100, + `INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (%d, '%d', 'test-%d', %d, false, 12.3, %d)`, table, i+10, i+10, i, i*100, i+10, ), ) is.NoErr(err) @@ -512,6 +492,8 @@ func insertTestRows(ctx context.Context, is *is.I, pool *pgxpool.Pool, table str } func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, fromID, toID int) { + is.Helper() + // Build the expected records slice var want []opencdc.Record @@ -524,16 +506,15 @@ func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, from }, Payload: opencdc.Change{ After: opencdc.StructuredData{ - "id": id, - "key": nil, - "column1": fmt.Sprintf("test-%d", i), - "column2": int32(i) * 100, //nolint:gosec // fine, we know the value is small enough - "column3": false, - "column4": big.NewRat(123, 10), - "column5": big.NewRat(14, 1), - "column6": nil, - "column7": nil, - "UppercaseColumn1": nil, + "id": id, + "key": []uint8(fmt.Sprintf("%d", id)), + "column1": fmt.Sprintf("test-%d", i), + "column2": int32(i) * 100, //nolint:gosec // fine, we know the value is small enough + "column3": false, + "column4": big.NewRat(123, 10), + // UppercaseColumn1 is a Postgres integer (4 bytes) + //nolint:gosec // integer overflow not possible, id is a small value always + "UppercaseColumn1": int32(id), }, }, Metadata: opencdc.Metadata{ @@ -547,9 +528,7 @@ func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, from cmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(opencdc.Record{}), cmpopts.IgnoreFields(opencdc.Record{}, "Position", "Metadata"), - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, } is.Equal("", cmp.Diff(want, got, cmpOpts...)) // mismatch (-want +got) } @@ -557,7 +536,7 @@ func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, 
from func TestCDCIterator_NextN(t *testing.T) { ctx := test.Context(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) t.Run("retrieve exact N records", func(t *testing.T) { is := is.New(t) @@ -565,8 +544,8 @@ func TestCDCIterator_NextN(t *testing.T) { <-i.sub.Ready() for j := 1; j <= 3; j++ { - _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, j+10, j, j*100)) + _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (%d, '%d', 'test-%d', %d, false, 12.3, 4)`, table, j+10, j+10, j, j*100)) is.NoErr(err) } @@ -588,8 +567,7 @@ func TestCDCIterator_NextN(t *testing.T) { for j, r := range allRecords { is.Equal(r.Operation, opencdc.OperationCreate) is.Equal(r.Key.(opencdc.StructuredData)["id"], int64(j+11)) - change := r.Payload - data := change.After.(opencdc.StructuredData) + data := r.Payload.After.(opencdc.StructuredData) is.Equal(data["column1"], fmt.Sprintf("test-%d", j+1)) //nolint:gosec // no risk to overflow is.Equal(data["column2"], (int32(j)+1)*100) @@ -602,8 +580,8 @@ func TestCDCIterator_NextN(t *testing.T) { <-i.sub.Ready() for j := 1; j <= 2; j++ { - _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, j+20, j, j*100)) + _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (%d, '%d', 'test-%d', %d, false, 12.3, 4)`, table, j+20, j+20, j, j*100)) is.NoErr(err) } @@ -663,23 +641,6 @@ func TestCDCIterator_NextN(t *testing.T) { _, err = i.NextN(ctx, -1) is.True(strings.Contains(err.Error(), "n must be greater than 0")) }) - - t.Run("subscription termination", func(t *testing.T) { - is := is.New(t) - i := testCDCIterator(ctx, t, pool, table, true) - <-i.sub.Ready() - - _, err := pool.Exec(ctx, fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (30, 'test-1', 100, false, 12.3, 14)`, table)) - is.NoErr(err) - - time.Sleep(100 * time.Millisecond) - is.NoErr(i.Teardown(ctx)) - - _, err = i.NextN(ctx, 5) - is.True(err != nil) - is.True(strings.Contains(err.Error(), "logical replication error")) - }) } func testCDCIterator(ctx context.Context, t *testing.T, pool *pgxpool.Pool, table string, start bool) *CDCIterator { @@ -742,7 +703,7 @@ func TestCDCIterator_Schema(t *testing.T) { ctx := test.Context(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) i := testCDCIterator(ctx, t, pool, table, true) <-i.sub.Ready() @@ -752,8 +713,8 @@ func TestCDCIterator_Schema(t *testing.T) { _, err := pool.Exec( ctx, - fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (6, 'bizz', 456, false, 12.3, 14)`, table), + fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (6, '6', 'bizz', 456, false, 12.3, 6)`, table), ) is.NoErr(err) @@ -775,8 +736,8 @@ func TestCDCIterator_Schema(t *testing.T) { _, err = pool.Exec( ctx, - fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column5, column6, column7, column101) - VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, 12345, '{"foo":"bar"}', '{"foo2":"baz2"}', 
'2023-09-09 10:00:00');`, table), + fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column101, "UppercaseColumn1") + VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, '2023-09-09 10:00:00', 7);`, table), ) is.NoErr(err) @@ -793,13 +754,13 @@ func TestCDCIterator_Schema(t *testing.T) { t.Run("column removed", func(t *testing.T) { is := is.New(t) - _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s DROP COLUMN column4, DROP COLUMN column5;`, table)) + _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s DROP COLUMN column4;`, table)) is.NoErr(err) _, err = pool.Exec( ctx, - fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column6, column7, column101) - VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '{"foo":"bar"}', '{"foo2":"baz2"}', '2023-09-09 10:00:00');`, table), + fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column101, "UppercaseColumn1") + VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '2023-09-09 10:00:00', 8);`, table), ) is.NoErr(err) diff --git a/source/logrepl/cleaner_test.go b/source/logrepl/cleaner_test.go index 7c67058..c8cf121 100644 --- a/source/logrepl/cleaner_test.go +++ b/source/logrepl/cleaner_test.go @@ -42,7 +42,7 @@ func Test_Cleanup(t *testing.T) { PublicationName: "conduitpub1", }, setup: func(t *testing.T) { - table := test.SetupTestTable(context.Background(), t, conn) + table := test.SetupTable(context.Background(), t, conn) test.CreatePublication(t, conn, "conduitpub1", []string{table}) test.CreateReplicationSlot(t, conn, "conduitslot1") }, @@ -54,7 +54,7 @@ func Test_Cleanup(t *testing.T) { PublicationName: "conduitpub2", }, setup: func(t *testing.T) { - table := test.SetupTestTable(context.Background(), t, conn) + table := test.SetupTable(context.Background(), t, conn) test.CreatePublication(t, conn, "conduitpub2", []string{table}) }, }, @@ -76,7 +76,7 @@ func Test_Cleanup(t *testing.T) { PublicationName: "conduitpub4", }, setup: func(t *testing.T) { - table := test.SetupTestTable(context.Background(), t, conn) + table := test.SetupTable(context.Background(), t, conn) test.CreatePublication(t, conn, "conduitpub4", []string{table}) }, wantErr: errors.New(`replication slot "conduitslot4" does not exist`), diff --git a/source/logrepl/combined_test.go b/source/logrepl/combined_test.go index d48976d..82bc5bd 100644 --- a/source/logrepl/combined_test.go +++ b/source/logrepl/combined_test.go @@ -50,7 +50,7 @@ func TestConfig_Validate(t *testing.T) { func TestCombinedIterator_New(t *testing.T) { ctx := test.Context(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) t.Run("fails to parse initial position", func(t *testing.T) { is := is.New(t) @@ -145,7 +145,7 @@ func TestCombinedIterator_NextN(t *testing.T) { is := is.New(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := test.SetupTestTable(ctx, t, pool) + table := test.SetupTable(ctx, t, pool) i, err := NewCombinedIterator(ctx, pool, Config{ Position: opencdc.Position{}, Tables: []string{table}, @@ -158,8 +158,8 @@ func TestCombinedIterator_NextN(t *testing.T) { // Add a record to the table for CDC mode testing _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) - VALUES (6, 'bizz', 1010, false, 872.2, 101, '{"foo12": "bar12"}', '{"foo13": "bar13"}')`, + `INSERT INTO %s (id, key, column1, column2, column3, column4, 
"UppercaseColumn1") + VALUES (6, '6', 'bizz', 1010, false, 872.2, 6)`, table, )) is.NoErr(err) @@ -276,15 +276,15 @@ func TestCombinedIterator_NextN(t *testing.T) { // Insert two more records for testing CDC batch _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) - VALUES (7, 'buzz', 10101, true, 121.9, 51, '{"foo7": "bar7"}', '{"foo8": "bar8"}')`, + `INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (7, '7', 'buzz', 10101, true, 121.9, 7)`, table, )) is.NoErr(err) _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) - VALUES (8, 'fizz', 20202, false, 232.8, 62, '{"foo9": "bar9"}', '{"foo10": "bar10"}')`, + `INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (8, '8', 'fizz', 20202, false, 232.8, 8)`, table, )) is.NoErr(err) @@ -320,9 +320,7 @@ func TestCombinedIterator_NextN(t *testing.T) { is.Equal("", cmp.Diff( expectedRecords[6], records[0].Payload.After.(opencdc.StructuredData), - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, )) is.NoErr(i.Ack(ctx, records[0].Position)) @@ -369,9 +367,6 @@ func testRecords() []opencdc.StructuredData { "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), - "column5": big.NewRat(4, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(1), }, { @@ -393,9 +388,6 @@ func testRecords() []opencdc.StructuredData { "column2": int32(789), "column3": false, "column4": nil, - "column5": big.NewRat(9, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(3), }, { @@ -405,34 +397,25 @@ func testRecords() []opencdc.StructuredData { "column2": nil, "column3": nil, "column4": big.NewRat(911, 10), // 91.1 - "column5": nil, - "column6": nil, - "column7": nil, - "UppercaseColumn1": nil, + "UppercaseColumn1": int32(4), }, { "id": int64(6), - "key": nil, + "key": []uint8("6"), "column1": "bizz", "column2": int32(1010), "column3": false, "column4": big.NewRat(8722, 10), // 872.2 - "column5": big.NewRat(101, 1), - "column6": []byte(`{"foo12": "bar12"}`), - "column7": []byte(`{"foo13": "bar13"}`), - "UppercaseColumn1": nil, + "UppercaseColumn1": int32(6), }, { "id": int64(7), - "key": nil, + "key": []uint8("7"), "column1": "buzz", "column2": int32(10101), "column3": true, "column4": big.NewRat(1219, 10), // 121.9 - "column5": big.NewRat(51, 1), - "column6": []byte(`{"foo7": "bar7"}`), - "column7": []byte(`{"foo8": "bar8"}`), - "UppercaseColumn1": nil, + "UppercaseColumn1": int32(7), }, } } diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index b4039db..6754ea5 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -22,6 +22,7 @@ import ( "github.com/conduitio/conduit-commons/opencdc" cschema "github.com/conduitio/conduit-commons/schema" + internal2 "github.com/conduitio/conduit-connector-postgres/internal" "github.com/conduitio/conduit-connector-postgres/source/logrepl/internal" "github.com/conduitio/conduit-connector-postgres/source/position" "github.com/conduitio/conduit-connector-postgres/source/schema" @@ -36,6 +37,8 @@ type CDCHandler struct { tableKeys map[string]string relationSet *internal.RelationSet + tableInfo *internal2.TableInfoFetcher + // batchSize is the largest number of records this handler will send at once. 
batchSize int flushInterval time.Duration @@ -55,6 +58,7 @@ type CDCHandler struct { func NewCDCHandler( ctx context.Context, rs *internal.RelationSet, + tableInfo *internal2.TableInfoFetcher, tableKeys map[string]string, out chan<- []opencdc.Record, withAvroSchema bool, @@ -66,6 +70,7 @@ func NewCDCHandler( relationSet: rs, recordBatch: make([]opencdc.Record, 0, batchSize), out: out, + tableInfo: tableInfo, withAvroSchema: withAvroSchema, keySchemas: make(map[string]cschema.Schema), payloadSchemas: make(map[string]cschema.Schema), @@ -126,6 +131,10 @@ func (h *CDCHandler) Handle(ctx context.Context, m pglogrepl.Message, lsn pglogr case *pglogrepl.RelationMessage: // We have to add the Relations to our Set so that we can decode our own output h.relationSet.Add(m) + err := h.tableInfo.Refresh(ctx, m.RelationName) + if err != nil { + return 0, fmt.Errorf("failed to refresh table info: %w", err) + } case *pglogrepl.InsertMessage: if err := h.handleInsert(ctx, m, lsn); err != nil { return 0, fmt.Errorf("logrepl handler insert: %w", err) @@ -164,15 +173,15 @@ func (h *CDCHandler) handleInsert( return fmt.Errorf("failed getting relation %v: %w", msg.RelationID, err) } - newValues, err := h.relationSet.Values(msg.RelationID, msg.Tuple) - if err != nil { - return fmt.Errorf("failed to decode new values: %w", err) - } - if err := h.updateAvroSchema(ctx, rel); err != nil { return fmt.Errorf("failed to update avro schema: %w", err) } + newValues, err := h.relationSet.Values(msg.RelationID, msg.Tuple, h.tableInfo.GetTable(rel.RelationName)) + if err != nil { + return fmt.Errorf("failed to decode new values: %w", err) + } + rec := sdk.Util.Source.NewRecordCreate( h.buildPosition(lsn), h.buildRecordMetadata(rel), @@ -197,7 +206,7 @@ func (h *CDCHandler) handleUpdate( return err } - newValues, err := h.relationSet.Values(msg.RelationID, msg.NewTuple) + newValues, err := h.relationSet.Values(msg.RelationID, msg.NewTuple, h.tableInfo.GetTable(rel.RelationName)) if err != nil { return fmt.Errorf("failed to decode new values: %w", err) } @@ -206,7 +215,7 @@ func (h *CDCHandler) handleUpdate( return fmt.Errorf("failed to update avro schema: %w", err) } - oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple) + oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple, h.tableInfo.GetTable(rel.RelationName)) if err != nil { // this is not a critical error, old values are optional, just log it // we use level "trace" intentionally to not clog up the logs in production @@ -238,7 +247,7 @@ func (h *CDCHandler) handleDelete( return err } - oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple) + oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple, h.tableInfo.GetTable(rel.RelationName)) if err != nil { return fmt.Errorf("failed to decode old values: %w", err) } @@ -324,7 +333,7 @@ func (h *CDCHandler) updateAvroSchema(ctx context.Context, rel *pglogrepl.Relati return nil } // Payload schema - avroPayloadSch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_payload", rel) + avroPayloadSch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_payload", rel, h.tableInfo.GetTable(rel.RelationName)) if err != nil { return fmt.Errorf("failed to extract payload schema: %w", err) } @@ -340,7 +349,7 @@ func (h *CDCHandler) updateAvroSchema(ctx context.Context, rel *pglogrepl.Relati h.payloadSchemas[rel.RelationName] = ps // Key schema - avroKeySch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_key", rel, h.tableKeys[rel.RelationName]) + avroKeySch, err := 
schema.Avro.ExtractLogrepl(rel.RelationName+"_key", rel, h.tableInfo.GetTable(rel.RelationName), h.tableKeys[rel.RelationName]) if err != nil { return fmt.Errorf("failed to extract key schema: %w", err) } diff --git a/source/logrepl/handler_test.go b/source/logrepl/handler_test.go index 5956377..19736b0 100644 --- a/source/logrepl/handler_test.go +++ b/source/logrepl/handler_test.go @@ -29,7 +29,7 @@ func TestHandler_Batching_BatchSizeReached(t *testing.T) { is := is.New(t) ch := make(chan []opencdc.Record, 1) - underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + underTest := NewCDCHandler(ctx, nil, nil, nil, ch, false, 5, time.Second) want := make([]opencdc.Record, 5) for i := 0; i < cap(want); i++ { rec := newTestRecord(i) @@ -51,7 +51,7 @@ func TestHandler_Batching_FlushInterval(t *testing.T) { ch := make(chan []opencdc.Record, 1) flushInterval := time.Second - underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, flushInterval) + underTest := NewCDCHandler(ctx, nil, nil, nil, ch, false, 5, flushInterval) want := make([]opencdc.Record, 3) for i := 0; i < cap(want); i++ { @@ -74,7 +74,7 @@ func TestHandler_Batching_ContextCancelled(t *testing.T) { is := is.New(t) ch := make(chan []opencdc.Record, 1) - underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + underTest := NewCDCHandler(ctx, nil, nil, nil, ch, false, 5, time.Second) cancel() <-ctx.Done() underTest.addToBatch(ctx, newTestRecord(0)) diff --git a/source/logrepl/internal/publication_test.go b/source/logrepl/internal/publication_test.go index f04a707..e4a109f 100644 --- a/source/logrepl/internal/publication_test.go +++ b/source/logrepl/internal/publication_test.go @@ -35,8 +35,8 @@ func TestCreatePublication(t *testing.T) { } tables := []string{ - test.SetupTestTable(ctx, t, pool), - test.SetupTestTable(ctx, t, pool), + test.SetupTable(ctx, t, pool), + test.SetupTable(ctx, t, pool), } for _, givenPubName := range pubNames { @@ -75,8 +75,8 @@ func TestCreatePublicationForTables(t *testing.T) { pool := test.ConnectPool(ctx, t, test.RegularConnString) tables := [][]string{ - {test.SetupTestTable(ctx, t, pool)}, - {test.SetupTestTable(ctx, t, pool), test.SetupTestTable(ctx, t, pool)}, + {test.SetupTable(ctx, t, pool)}, + {test.SetupTable(ctx, t, pool), test.SetupTable(ctx, t, pool)}, } for _, givenTables := range tables { diff --git a/source/logrepl/internal/relationset.go b/source/logrepl/internal/relationset.go index ccc718c..ed0cffb 100644 --- a/source/logrepl/internal/relationset.go +++ b/source/logrepl/internal/relationset.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" + "github.com/conduitio/conduit-connector-postgres/internal" "github.com/conduitio/conduit-connector-postgres/source/types" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgtype" @@ -50,7 +51,7 @@ func (rs *RelationSet) Get(id uint32) (*pglogrepl.RelationMessage, error) { return msg, nil } -func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData) (map[string]any, error) { +func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData, tableInfo *internal.TableInfo) (map[string]any, error) { if row == nil { return nil, errors.New("no tuple data") } @@ -65,7 +66,7 @@ func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData) (map[string]a // assert same number of row and rel columns for i, tuple := range row.Columns { col := rel.Columns[i] - v, decodeErr := rs.decodeValue(col, tuple.Data) + v, decodeErr := rs.decodeValue(col, tableInfo.Columns[col.Name], tuple.Data) if decodeErr != nil { return nil, 
fmt.Errorf("failed to decode value for column %q: %w", col.Name, decodeErr) } @@ -84,7 +85,7 @@ func (rs *RelationSet) oidToCodec(id uint32) pgtype.Codec { return dt.Codec } -func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data []byte) (any, error) { +func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, colInfo *internal.ColumnInfo, data []byte) (any, error) { decoder := rs.oidToCodec(col.DataType) // This workaround is due to an issue in pgx v5.7.1. // Namely, that version introduces an XML codec @@ -105,7 +106,7 @@ func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data [] return nil, fmt.Errorf("failed to decode value of pgtype %v: %w", col.DataType, err) } - v, err := types.Format(col.DataType, val) + v, err := types.Format(col.DataType, val, colInfo.IsNotNull) if err != nil { return nil, fmt.Errorf("failed to format column %q type %T: %w", col.Name, val, err) } diff --git a/source/logrepl/internal/relationset_test.go b/source/logrepl/internal/relationset_test.go index c120bfe..b26a582 100644 --- a/source/logrepl/internal/relationset_test.go +++ b/source/logrepl/internal/relationset_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/conduitio/conduit-connector-postgres/internal" "github.com/conduitio/conduit-connector-postgres/test" "github.com/google/go-cmp/cmp" "github.com/jackc/pglogrepl" @@ -49,7 +50,6 @@ func TestRelationSetAllTypes(t *testing.T) { is := is.New(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) - table := setupTableAllTypes(ctx, t, pool) _, messages := setupSubscription(ctx, t, pool, table) insertRowAllTypes(ctx, t, pool, table) @@ -77,6 +77,10 @@ func TestRelationSetAllTypes(t *testing.T) { break } + tableInfo := internal.NewTableInfoFetcher(pool) + err := tableInfo.Refresh(ctx, table) + is.NoErr(err) + rs := NewRelationSet() rs.Add(rel) @@ -87,7 +91,7 @@ func TestRelationSetAllTypes(t *testing.T) { t.Run("with builtin plugin", func(t *testing.T) { is := is.New(t) - values, err := rs.Values(ins.RelationID, ins.Tuple) + values, err := rs.Values(ins.RelationID, ins.Tuple, tableInfo.GetTable(table)) is.NoErr(err) isValuesAllTypes(is, values) }) @@ -95,7 +99,7 @@ t.Run("with standalone plugin", func(t *testing.T) { is := is.New(t) - values, err := rs.Values(ins.RelationID, ins.Tuple) + values, err := rs.Values(ins.RelationID, ins.Tuple, tableInfo.GetTable(table)) is.NoErr(err) isValuesAllTypesStandalone(is, values) }) @@ -105,52 +109,51 @@ func setupTableAllTypes(ctx context.Context, t *testing.T, conn test.Querier) string { is := is.New(t) table := test.RandomIdentifier(t) - query := ` - CREATE TABLE %s ( - id bigserial PRIMARY KEY, - col_bit bit(8), - col_varbit varbit(8), - col_boolean boolean, - col_box box, - col_bytea bytea, - col_char char(3), - col_varchar varchar(10), - col_cidr cidr, - col_circle circle, - col_date date, - col_float4 float4, - col_float8 float8, - col_inet inet, - col_int2 int2, - col_int4 int4, - col_int8 int8, - col_interval interval, - col_json json, - col_jsonb jsonb, - col_line line, - col_lseg lseg, - col_macaddr macaddr, - col_macaddr8 macaddr8, - col_money money, - col_numeric numeric(8,2), - col_path path, - col_pg_lsn pg_lsn, - col_pg_snapshot pg_snapshot, - col_point point, - col_polygon polygon, - col_serial2 serial2, - col_serial4 serial4, - col_serial8 serial8, - col_text text, - col_time time, - col_timetz timetz, - col_timestamp timestamp, -
col_timestamptz timestamptz, - col_tsquery tsquery, - col_tsvector tsvector, - col_uuid uuid, - col_xml xml - )` + query := `CREATE TABLE %s ( + id bigserial PRIMARY KEY, + col_bit bit(8) NOT NULL, + col_varbit varbit(8) NOT NULL, + col_boolean boolean NOT NULL, + col_box box NOT NULL, + col_bytea bytea NOT NULL, + col_char char(3) NOT NULL, + col_varchar varchar(10) NOT NULL, + col_cidr cidr NOT NULL, + col_circle circle NOT NULL, + col_date date NOT NULL, + col_float4 float4 NOT NULL, + col_float8 float8 NOT NULL, + col_inet inet NOT NULL, + col_int2 int2 NOT NULL, + col_int4 int4 NOT NULL, + col_int8 int8 NOT NULL, + col_interval interval NOT NULL, + col_json json NOT NULL, + col_jsonb jsonb NOT NULL, + col_line line NOT NULL, + col_lseg lseg NOT NULL, + col_macaddr macaddr NOT NULL, + col_macaddr8 macaddr8 NOT NULL, + col_money money NOT NULL, + col_numeric numeric(8,2) NOT NULL, + col_path path NOT NULL, + col_pg_lsn pg_lsn NOT NULL, + col_pg_snapshot pg_snapshot NOT NULL, + col_point point NOT NULL, + col_polygon polygon NOT NULL, + col_serial2 serial2 NOT NULL, + col_serial4 serial4 NOT NULL, + col_serial8 serial8 NOT NULL, + col_text text NOT NULL, + col_time time NOT NULL, + col_timetz timetz NOT NULL, + col_timestamp timestamp NOT NULL, + col_timestamptz timestamptz NOT NULL, + col_tsquery tsquery NOT NULL, + col_tsvector tsvector NOT NULL, + col_uuid uuid NOT NULL, + col_xml xml NOT NULL +)` query = fmt.Sprintf(query, table) _, err := conn.Exec(ctx, query) is.NoErr(err) @@ -353,9 +356,7 @@ func isValuesAllTypes(is *is.I, got map[string]any) { cmp.Comparer(func(x, y netip.Prefix) bool { return x.String() == y.String() }), - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, )) } @@ -453,8 +454,6 @@ func isValuesAllTypesStandalone(is *is.I, got map[string]any) { cmp.Comparer(func(x, y netip.Prefix) bool { return x.String() == y.String() }), - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, )) } diff --git a/source/logrepl/internal/subscription_test.go b/source/logrepl/internal/subscription_test.go index da6a587..3365afb 100644 --- a/source/logrepl/internal/subscription_test.go +++ b/source/logrepl/internal/subscription_test.go @@ -42,8 +42,8 @@ func TestSubscription_WithRepmgr(t *testing.T) { var ( ctx = test.Context(t) pool = test.ConnectPool(ctx, t, test.RepmgrConnString) - table1 = test.SetupTestTable(ctx, t, pool) - table2 = test.SetupTestTable(ctx, t, pool) + table1 = test.SetupTable(ctx, t, pool) + table2 = test.SetupTable(ctx, t, pool) ) sub, messages := setupSubscription(ctx, t, pool, table1, table2) @@ -64,9 +64,11 @@ func TestSubscription_WithRepmgr(t *testing.T) { t.Run("first insert table1", func(t *testing.T) { is := is.New(t) - query := `INSERT INTO %s (id, column1, column2, column3) - VALUES (6, 'bizz', 456, false)` - _, err := pool.Exec(ctx, fmt.Sprintf(query, table1)) + query := fmt.Sprintf( + `INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (6, '6', 'bizz', 456, false, 12.3, 61)`, + table1) + _, err := pool.Exec(ctx, query) is.NoErr(err) _ = fetchAndAssertMessageTypes( @@ -82,9 +84,11 @@ func TestSubscription_WithRepmgr(t *testing.T) { t.Run("second insert table1", func(t *testing.T) { is := is.New(t) - query := `INSERT INTO %s (id, column1, column2, column3) - VALUES (7, 'bizz', 456, false)` - _, err := pool.Exec(ctx, fmt.Sprintf(query, table1)) + query := fmt.Sprintf( + `INSERT INTO %s (id, key, column1, column2, column3, column4, 
"UppercaseColumn1") + VALUES (7, '7', 'bizz', 456, false, 12.3, 61)`, + table1) + _, err := pool.Exec(ctx, query) is.NoErr(err) _ = fetchAndAssertMessageTypes( @@ -156,15 +160,17 @@ func TestSubscription_ClosedContext(t *testing.T) { var ( is = is.New(t) pool = test.ConnectPool(ctx, t, test.RepmgrConnString) - table = test.SetupTestTable(ctx, t, pool) + table = test.SetupTable(ctx, t, pool) ) sub, messages := setupSubscription(ctx, t, pool, table) // insert to get new messages into publication - query := `INSERT INTO %s (id, column1, column2, column3) - VALUES (6, 'bizz', 456, false)` - _, err := pool.Exec(ctx, fmt.Sprintf(query, table)) + query := fmt.Sprintf( + `INSERT INTO %s (id, key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES (6, '6', 'bizz', 456, false, 12.3, 61)`, + table) + _, err := pool.Exec(ctx, query) is.NoErr(err) cancel() diff --git a/source/schema/avro.go b/source/schema/avro.go index 0026a2b..71c06a4 100644 --- a/source/schema/avro.go +++ b/source/schema/avro.go @@ -19,6 +19,7 @@ import ( "fmt" "slices" + "github.com/conduitio/conduit-connector-postgres/internal" "github.com/hamba/avro/v2" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgconn" @@ -65,7 +66,7 @@ type avroExtractor struct { // ExtractLogrepl extracts an Avro schema from the given pglogrepl.RelationMessage. // If `fieldNames` are specified, then only the given fields will be included in the schema. -func (a avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.RelationMessage, fieldNames ...string) (*avro.RecordSchema, error) { +func (a *avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.RelationMessage, tableInfo *internal.TableInfo, fieldNames ...string) (*avro.RecordSchema, error) { var fields []pgconn.FieldDescription for i := range rel.Columns { @@ -76,12 +77,12 @@ func (a avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.Relation }) } - return a.Extract(schemaName, fields, fieldNames...) + return a.Extract(schemaName, tableInfo, fields, fieldNames...) } // Extract extracts an Avro schema from the given Postgres field descriptions. // If `fieldNames` are specified, then only the given fields will be included in the schema. 
-func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescription, fieldNames ...string) (*avro.RecordSchema, error) { +func (a *avroExtractor) Extract(schemaName string, tableInfo *internal.TableInfo, fields []pgconn.FieldDescription, fieldNames ...string) (*avro.RecordSchema, error) { var avroFields []*avro.Field for _, f := range fields { @@ -94,7 +95,7 @@ func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescript return nil, fmt.Errorf("field %q with OID %d cannot be resolved", f.Name, f.DataTypeOID) } - s, err := a.extractType(t, f.TypeModifier) + s, err := a.extractType(t, f.TypeModifier, tableInfo.Columns[f.Name].IsNotNull) if err != nil { return nil, err } @@ -119,7 +120,25 @@ func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescript return sch, nil } -func (a *avroExtractor) extractType(t *pgtype.Type, typeMod int32) (avro.Schema, error) { +func (a *avroExtractor) extractType(t *pgtype.Type, typeMod int32, notNull bool) (avro.Schema, error) { + baseType, err := a.extractBaseType(t, typeMod) + if err != nil { + return nil, err + } + + if !notNull { + schema, err := avro.NewUnionSchema([]avro.Schema{avro.NewNullSchema(), baseType}) + if err != nil { + return nil, fmt.Errorf("failed to create avro union schema for nullable type %v: %w", baseType, err) + } + + return schema, nil + } + + return baseType, nil +} + +func (a *avroExtractor) extractBaseType(t *pgtype.Type, typeMod int32) (avro.Schema, error) { if ps, ok := a.avroMap[t.Name]; ok { return ps, nil } diff --git a/source/schema/avro_integration_test.go b/source/schema/avro_integration_test.go new file mode 100644 index 0000000..36a8003 --- /dev/null +++ b/source/schema/avro_integration_test.go @@ -0,0 +1,560 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
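+ +// This integration test creates a table with nullable and NOT NULL variants of common column +// types, extracts an Avro schema from it, and round-trips one row through avro.Marshal and +// avro.Unmarshal to verify both the schema shape and the decoded values.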
+ +package schema + +import ( + "cmp" + "context" + "fmt" + "math/big" + "reflect" + "slices" + "strings" + "testing" + "time" + + "github.com/conduitio/conduit-connector-postgres/internal" + "github.com/conduitio/conduit-connector-postgres/source/cpool" + "github.com/conduitio/conduit-connector-postgres/source/types" + "github.com/conduitio/conduit-connector-postgres/test" + "github.com/hamba/avro/v2" + "github.com/jackc/pgx/v5/pgconn" + "github.com/matryer/is" +) + +func Test_AvroExtract(t *testing.T) { + ctx := test.Context(t) + is := is.New(t) + + c := test.ConnectSimple(ctx, t, test.RegularConnString) + connPool, err := cpool.New(ctx, test.RegularConnString) + is.NoErr(err) + + table := setupAvroTestTable(ctx, t, c) + tableInfoFetcher := internal.NewTableInfoFetcher(connPool) + err = tableInfoFetcher.Refresh(ctx, table) + is.NoErr(err) + + insertAvroTestRow(ctx, t, c, table) + + rows, err := c.Query(ctx, "SELECT * FROM "+table) + is.NoErr(err) + defer rows.Close() + + rows.Next() + + values, err := rows.Values() + is.NoErr(err) + + fields := rows.FieldDescriptions() + + schemaExtracted, err := Avro.Extract(table, tableInfoFetcher.GetTable(table), fields) + is.NoErr(err) + + t.Run("schema is parsable", func(t *testing.T) { + is := is.New(t) + is.NoErr(err) + is.Equal(schemaExtracted.String(), avroTestSchema(t, table).String()) + + _, err = avro.Parse(schemaExtracted.String()) + is.NoErr(err) + }) + + t.Run("serde row", func(t *testing.T) { + is := is.New(t) + + row := avrolizeMap(fields, values) + + sch, err := avro.Parse(schemaExtracted.String()) + is.NoErr(err) + + data, err := avro.Marshal(sch, row) + is.NoErr(err) + is.True(len(data) > 0) + + decoded := make(map[string]any) + is.NoErr(avro.Unmarshal(sch, data, &decoded)) + + is.Equal(len(decoded), len(row)) + + // Compare all fields + compareValue(is, row, decoded, "col_bytea") + compareValue(is, row, decoded, "col_bytea_not_null") + compareValue(is, row, decoded, "col_varchar") + compareValue(is, row, decoded, "col_varchar_not_null") + compareValue(is, row, decoded, "col_date") + compareValue(is, row, decoded, "col_date_not_null") + compareValue(is, row, decoded, "col_float4") + compareValue(is, row, decoded, "col_float4_not_null") + compareValue(is, row, decoded, "col_float8") + compareValue(is, row, decoded, "col_float8_not_null") + + compareIntValue(is, row, decoded, "col_int2") + compareIntValue(is, row, decoded, "col_int2_not_null") + compareIntValue(is, row, decoded, "col_int4") + compareIntValue(is, row, decoded, "col_int4_not_null") + + compareValue(is, row, decoded, "col_int8") + compareValue(is, row, decoded, "col_int8_not_null") + + compareNumericValue(is, row, decoded, "col_numeric") + compareNumericValue(is, row, decoded, "col_numeric_not_null") + + compareValue(is, row, decoded, "col_text") + compareValue(is, row, decoded, "col_text_not_null") + + compareTimestampValue(is, row, decoded, "col_timestamp") + compareTimestampValue(is, row, decoded, "col_timestamp_not_null") + compareTimestampValue(is, row, decoded, "col_timestamptz") + compareTimestampValue(is, row, decoded, "col_timestamptz_not_null") + + compareValue(is, row, decoded, "col_uuid") + compareValue(is, row, decoded, "col_uuid_not_null") + compareValue(is, row, decoded, "col_json") + compareValue(is, row, decoded, "col_json_not_null") + compareValue(is, row, decoded, "col_jsonb") + compareValue(is, row, decoded, "col_jsonb_not_null") + compareValue(is, row, decoded, "col_bool") + compareValue(is, row, decoded, "col_bool_not_null") + + // Serial types are 
integers, so use compareIntValue + compareIntValue(is, row, decoded, "col_serial") + compareIntValue(is, row, decoded, "col_serial_not_null") + compareIntValue(is, row, decoded, "col_smallserial") + compareIntValue(is, row, decoded, "col_smallserial_not_null") + compareValue(is, row, decoded, "col_bigserial") + compareValue(is, row, decoded, "col_bigserial_not_null") + }) +} + +// Extracted comparison functions +func compareValue(is *is.I, wantMap, gotMap map[string]any, key string) { + is.Helper() + + want := wantMap[key] + got := gotMap[key] + + if want == nil { + is.Equal(nil, got) + return + } + + // If row value is a pointer, dereference it + wantReflect := reflect.ValueOf(want) + if wantReflect.Kind() == reflect.Ptr { + if wantReflect.IsNil() { + is.Equal(nil, got) + return + } + want = wantReflect.Elem().Interface() + } + + is.Equal(want, got) +} + +func compareIntValue(is *is.I, wantMap, gotMap map[string]any, key string) { + is.Helper() + + want := wantMap[key] + got := gotMap[key] + + if want == nil { + is.Equal(nil, got) + return + } + + switch v := want.(type) { + case *int16: + is.Equal(int(*v), got) + case *int32: + is.Equal(int(*v), got) + case int16: + is.Equal(int(v), got) + case int32: + is.Equal(int(v), got) + default: + is.Equal(want, got) + } +} + +func compareNumericValue(is *is.I, wantMap, gotMap map[string]any, key string) { + is.Helper() + + want := wantMap[key] + got := gotMap[key] + + if want == nil { + is.Equal(nil, got) + return + } + + numRow, ok := want.(*big.Rat) + if !ok || numRow == nil { + is.Equal(nil, got) + return + } + + numDecoded := got.(*big.Rat) + is.Equal(numRow.RatString(), numDecoded.RatString()) +} + +func compareTimestampValue(is *is.I, wantMap, gotMap map[string]any, key string) { + is.Helper() + + want := wantMap[key] + got := gotMap[key] + + if want == nil { + is.Equal(nil, got) + return + } + + var wantTS time.Time + switch v := want.(type) { + case *time.Time: + wantTS = *v + case time.Time: + wantTS = v + } + + var gotTS time.Time + switch v := got.(type) { + case map[string]interface{}: + gotTS = v["long.local-timestamp-micros"].(time.Time) + case time.Time: + gotTS = v + } + + is.Equal(wantTS.UTC().String(), gotTS.UTC().String()) +} + +func setupAvroTestTable(ctx context.Context, t *testing.T, conn test.Querier) string { + is := is.New(t) + table := test.RandomIdentifier(t) + + query := ` + CREATE TABLE %s ( + id bigserial PRIMARY KEY, + col_bytea bytea, + col_bytea_not_null bytea NOT NULL, + col_varchar varchar(10), + col_varchar_not_null varchar(10) NOT NULL, + col_date date, + col_date_not_null date NOT NULL, + col_float4 float4, + col_float4_not_null float4 NOT NULL, + col_float8 float8, + col_float8_not_null float8 NOT NULL, + col_int2 int2, + col_int2_not_null int2 NOT NULL, + col_int4 int4, + col_int4_not_null int4 NOT NULL, + col_int8 int8, + col_int8_not_null int8 NOT NULL, + col_numeric numeric(8,2), + col_numeric_not_null numeric(8,2) NOT NULL, + col_text text, + col_text_not_null text NOT NULL, + col_timestamp timestamp, + col_timestamp_not_null timestamp NOT NULL, + col_timestamptz timestamptz, + col_timestamptz_not_null timestamptz NOT NULL, + col_uuid uuid, + col_uuid_not_null uuid NOT NULL, + col_json json, + col_json_not_null json NOT NULL, + col_jsonb jsonb, + col_jsonb_not_null jsonb NOT NULL, + col_bool bool, + col_bool_not_null bool NOT NULL, + col_serial serial, + col_serial_not_null serial NOT NULL, + col_smallserial smallserial, + col_smallserial_not_null smallserial NOT NULL, + col_bigserial bigserial, + col_bigserial_not_null
bigserial NOT NULL + )` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(ctx, query) + is.NoErr(err) + + return table +} + +func insertAvroTestRow(ctx context.Context, t *testing.T, conn test.Querier, table string) { + is := is.New(t) + query := ` + INSERT INTO %s ( + col_bytea, + col_bytea_not_null, + col_varchar, + col_varchar_not_null, + col_date, + col_date_not_null, + col_float4, + col_float4_not_null, + col_float8, + col_float8_not_null, + col_int2, + col_int2_not_null, + col_int4, + col_int4_not_null, + col_int8, + col_int8_not_null, + col_numeric, + col_numeric_not_null, + col_text, + col_text_not_null, + col_timestamp, + col_timestamp_not_null, + col_timestamptz, + col_timestamptz_not_null, + col_uuid, + col_uuid_not_null, + col_json, + col_json_not_null, + col_jsonb, + col_jsonb_not_null, + col_bool, + col_bool_not_null, + col_serial, + col_serial_not_null, + col_smallserial, + col_smallserial_not_null, + col_bigserial, + col_bigserial_not_null + ) VALUES ( + '\x07', -- col_bytea + '\x08', -- col_bytea_not_null + '9', -- col_varchar + '10', -- col_varchar_not_null + '2022-03-14', -- col_date + '2022-03-15', -- col_date_not_null + 15, -- col_float4 + 16, -- col_float4_not_null + 16.16, -- col_float8 + 17.17, -- col_float8_not_null + 32767, -- col_int2 + 32766, -- col_int2_not_null + 2147483647, -- col_int4 + 2147483646, -- col_int4_not_null + 9223372036854775807, -- col_int8 + 9223372036854775806, -- col_int8_not_null + '292929.29', -- col_numeric + '292928.28', -- col_numeric_not_null + 'foo bar baz', -- col_text + 'foo bar baz not null', -- col_text_not_null + '2022-03-14 15:16:17', -- col_timestamp + '2022-03-14 15:16:18', -- col_timestamp_not_null + '2022-03-14 15:16:17-08', -- col_timestamptz + '2022-03-14 15:16:18-08', -- col_timestamptz_not_null + 'bd94ee0b-564f-4088-bf4e-8d5e626caf66', -- col_uuid + 'bd94ee0b-564f-4088-bf4e-8d5e626caf67', -- col_uuid_not_null + '{"key": "value"}', -- col_json + '{"key": "value_not_null"}', -- col_json_not_null + '{"key": "value"}', -- col_jsonb + '{"key": "value_not_null"}', -- col_jsonb_not_null + true, -- col_bool + false, -- col_bool_not_null + 100, -- col_serial + 101, -- col_serial_not_null + 200, -- col_smallserial + 201, -- col_smallserial_not_null + 300, -- col_bigserial + 301 -- col_bigserial_not_null + )` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(ctx, query) + is.NoErr(err) +} + +func avroTestSchema(t *testing.T, table string) avro.Schema { + is := is.New(t) + + fields := []*avro.Field{ + // Primary key - bigserial (not null) + assert(avro.NewField("id", avro.NewPrimitiveSchema(avro.Long, nil))), + + // bytea fields + assert(avro.NewField("col_bytea", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Bytes, nil), + })))), + assert(avro.NewField("col_bytea_not_null", avro.NewPrimitiveSchema(avro.Bytes, nil))), + + // varchar fields + assert(avro.NewField("col_varchar", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.String, nil), + })))), + assert(avro.NewField("col_varchar_not_null", avro.NewPrimitiveSchema(avro.String, nil))), + + // date fields + assert(avro.NewField("col_date", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Int, avro.NewPrimitiveLogicalSchema(avro.Date)), + })))), + assert(avro.NewField("col_date_not_null", avro.NewPrimitiveSchema( + avro.Int, + 
avro.NewPrimitiveLogicalSchema(avro.Date), + ))), + + // float4 fields + assert(avro.NewField("col_float4", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Float, nil), + })))), + assert(avro.NewField("col_float4_not_null", avro.NewPrimitiveSchema(avro.Float, nil))), + + // float8 fields + assert(avro.NewField("col_float8", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Double, nil), + })))), + assert(avro.NewField("col_float8_not_null", avro.NewPrimitiveSchema(avro.Double, nil))), + + // int2 fields + assert(avro.NewField("col_int2", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Int, nil), + })))), + assert(avro.NewField("col_int2_not_null", avro.NewPrimitiveSchema(avro.Int, nil))), + + // int4 fields + assert(avro.NewField("col_int4", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Int, nil), + })))), + assert(avro.NewField("col_int4_not_null", avro.NewPrimitiveSchema(avro.Int, nil))), + + // int8 fields + assert(avro.NewField("col_int8", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Long, nil), + })))), + assert(avro.NewField("col_int8_not_null", avro.NewPrimitiveSchema(avro.Long, nil))), + + // numeric fields + assert(avro.NewField("col_numeric", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Bytes, avro.NewDecimalLogicalSchema(8, 2)), + })))), + assert(avro.NewField("col_numeric_not_null", avro.NewPrimitiveSchema( + avro.Bytes, + avro.NewDecimalLogicalSchema(8, 2), + ))), + + // text fields + assert(avro.NewField("col_text", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.String, nil), + })))), + assert(avro.NewField("col_text_not_null", avro.NewPrimitiveSchema(avro.String, nil))), + + // timestamp fields + assert(avro.NewField("col_timestamp", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Long, avro.NewPrimitiveLogicalSchema(avro.LocalTimestampMicros)), + })))), + assert(avro.NewField("col_timestamp_not_null", avro.NewPrimitiveSchema( + avro.Long, + avro.NewPrimitiveLogicalSchema(avro.LocalTimestampMicros), + ))), + + // timestamptz fields + assert(avro.NewField("col_timestamptz", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Long, avro.NewPrimitiveLogicalSchema(avro.TimestampMicros)), + })))), + assert(avro.NewField("col_timestamptz_not_null", avro.NewPrimitiveSchema( + avro.Long, + avro.NewPrimitiveLogicalSchema(avro.TimestampMicros), + ))), + + // uuid fields + assert(avro.NewField("col_uuid", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.String, avro.NewPrimitiveLogicalSchema(avro.UUID)), + })))), + assert(avro.NewField("col_uuid_not_null", avro.NewPrimitiveSchema( + avro.String, + avro.NewPrimitiveLogicalSchema(avro.UUID), + ))), + + // json fields (represented as bytes in Avro) + assert(avro.NewField("col_json", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Bytes, nil), + })))), + assert(avro.NewField("col_json_not_null", 
avro.NewPrimitiveSchema(avro.Bytes, nil))), + + // jsonb fields (represented as bytes in Avro) + assert(avro.NewField("col_jsonb", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Bytes, nil), + })))), + assert(avro.NewField("col_jsonb_not_null", avro.NewPrimitiveSchema(avro.Bytes, nil))), + + // bool fields + assert(avro.NewField("col_bool", assert(avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.Null, nil), + avro.NewPrimitiveSchema(avro.Boolean, nil), + })))), + assert(avro.NewField("col_bool_not_null", avro.NewPrimitiveSchema(avro.Boolean, nil))), + + // serial fields (represented as int in Avro) + assert(avro.NewField("col_serial", avro.NewPrimitiveSchema(avro.Int, nil))), + assert(avro.NewField("col_serial_not_null", avro.NewPrimitiveSchema(avro.Int, nil))), + + // smallserial fields (represented as int in Avro) + assert(avro.NewField("col_smallserial", avro.NewPrimitiveSchema(avro.Int, nil))), + assert(avro.NewField("col_smallserial_not_null", avro.NewPrimitiveSchema(avro.Int, nil))), + + // bigserial fields (represented as long in Avro) + assert(avro.NewField("col_bigserial", avro.NewPrimitiveSchema(avro.Long, nil))), + assert(avro.NewField("col_bigserial_not_null", avro.NewPrimitiveSchema(avro.Long, nil))), + } + + slices.SortFunc(fields, func(a, b *avro.Field) int { + return cmp.Compare(a.Name(), b.Name()) + }) + + s, err := avro.NewRecordSchema(table, "", fields) + is.NoErr(err) + + return s +} + +func avrolizeMap(fields []pgconn.FieldDescription, values []any) map[string]any { + row := make(map[string]any) + + for i, f := range fields { + isNotNull := f.Name == "id" || + f.Name == "col_bigserial" || + f.Name == "col_serial" || + f.Name == "col_smallserial" || + strings.HasSuffix(f.Name, "_not_null") + + row[f.Name] = assert(types.Format(f.DataTypeOID, values[i], isNotNull)) + } + + return row +} + +func assert[T any](a T, err error) T { + if err != nil { + panic(err) + } + + return a +} diff --git a/source/schema/avro_test.go b/source/schema/avro_test.go deleted file mode 100644 index 448e2a9..0000000 --- a/source/schema/avro_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright © 2024 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
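Editor's note on the schema convention the new test above asserts, before the removed version of the test below: a nullable column now maps to an Avro ["null", T] union, while a NOT NULL column keeps the bare type. A minimal sketch using the same hamba/avro calls that appear in this diff (the column type is illustrative):

```go
package example

import (
	"fmt"

	"github.com/hamba/avro/v2"
)

// nullableVsNotNull prints the two schema shapes avroTestSchema builds:
// a bare primitive for NOT NULL columns, a ["null", T] union otherwise.
func nullableVsNotNull() error {
	notNull := avro.NewPrimitiveSchema(avro.String, nil)

	nullable, err := avro.NewUnionSchema([]avro.Schema{
		avro.NewPrimitiveSchema(avro.Null, nil),
		avro.NewPrimitiveSchema(avro.String, nil),
	})
	if err != nil {
		return err
	}

	fmt.Println(notNull.String())  // "string"
	fmt.Println(nullable.String()) // ["null","string"]
	return nil
}
```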
- -package schema - -import ( - "cmp" - "context" - "fmt" - "math/big" - "slices" - "testing" - "time" - - "github.com/conduitio/conduit-connector-postgres/source/types" - "github.com/conduitio/conduit-connector-postgres/test" - "github.com/hamba/avro/v2" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" - "github.com/matryer/is" -) - -func Test_AvroExtract(t *testing.T) { - ctx := test.Context(t) - is := is.New(t) - - c := test.ConnectSimple(ctx, t, test.RegularConnString) - table := setupAvroTestTable(ctx, t, c) - insertAvroTestRow(ctx, t, c, table) - - rows, err := c.Query(ctx, "SELECT * FROM "+table) - is.NoErr(err) - defer rows.Close() - - rows.Next() - - values, err := rows.Values() - is.NoErr(err) - - fields := rows.FieldDescriptions() - - sch, err := Avro.Extract(table, fields) - is.NoErr(err) - - t.Run("schema is parsable", func(t *testing.T) { - is := is.New(t) - is.NoErr(err) - is.Equal(sch, avroTestSchema(t, table)) - - _, err = avro.Parse(sch.String()) - is.NoErr(err) - }) - - t.Run("serde row", func(t *testing.T) { - is := is.New(t) - - row := avrolizeMap(fields, values) - - sch, err := avro.Parse(sch.String()) - is.NoErr(err) - - data, err := avro.Marshal(sch, row) - is.NoErr(err) - is.True(len(data) > 0) - - decoded := make(map[string]any) - is.NoErr(avro.Unmarshal(sch, data, &decoded)) - - is.Equal(len(decoded), len(row)) - is.Equal(row["col_boolean"], decoded["col_boolean"]) - is.Equal(row["col_bytea"], decoded["col_bytea"]) - is.Equal(row["col_varchar"], decoded["col_varchar"]) - is.Equal(row["col_date"], decoded["col_date"]) - is.Equal(row["col_float4"], decoded["col_float4"]) - is.Equal(row["col_float8"], decoded["col_float8"]) - - colInt2 := int(row["col_int2"].(int16)) - is.Equal(colInt2, decoded["col_int2"]) - - colInt4 := int(row["col_int4"].(int32)) - is.Equal(colInt4, decoded["col_int4"]) - - is.Equal(row["col_int8"], decoded["col_int8"]) - - numRow := row["col_numeric"].(*big.Rat) - numDecoded := decoded["col_numeric"].(*big.Rat) - is.Equal(numRow.RatString(), numDecoded.RatString()) - - is.Equal(row["col_text"], decoded["col_text"]) - - rowTS, colTS := row["col_timestamp"].(time.Time), decoded["col_timestamp"].(time.Time) - is.Equal(rowTS.UTC().String(), colTS.UTC().String()) - - rowTSTZ, colTSTZ := row["col_timestamptz"].(time.Time), decoded["col_timestamptz"].(time.Time) - is.Equal(rowTSTZ.UTC().String(), colTSTZ.UTC().String()) - - is.Equal(row["col_uuid"], decoded["col_uuid"]) - }) -} - -func setupAvroTestTable(ctx context.Context, t *testing.T, conn test.Querier) string { - is := is.New(t) - table := test.RandomIdentifier(t) - - query := ` - CREATE TABLE %s ( - col_boolean boolean, - col_bytea bytea, - col_varchar varchar(10), - col_date date, - col_float4 float4, - col_float8 float8, - col_int2 int2, - col_int4 int4, - col_int8 int8, - col_numeric numeric(8,2), - col_text text, - col_timestamp timestamp, - col_timestamptz timestamptz, - col_uuid uuid - )` - query = fmt.Sprintf(query, table) - _, err := conn.Exec(ctx, query) - is.NoErr(err) - - return table -} - -func insertAvroTestRow(ctx context.Context, t *testing.T, conn test.Querier, table string) { - is := is.New(t) - query := ` - INSERT INTO %s ( - col_boolean, - col_bytea, - col_varchar, - col_date, - col_float4, - col_float8, - col_int2, - col_int4, - col_int8, - col_numeric, - col_text, - col_timestamp, - col_timestamptz, - col_uuid - ) VALUES ( - true, -- col_boolean - '\x07', -- col_bytea - '9', -- col_varchar - '2022-03-14', -- col_date - 15, -- col_float4 - 16.16, -- 
col_float8 - 32767, -- col_int2 - 2147483647, -- col_int4 - 9223372036854775807, -- col_int8 - '292929.29', -- col_numeric - 'foo bar baz', -- col_text - '2022-03-14 15:16:17', -- col_timestamp - '2022-03-14 15:16:17-08', -- col_timestamptz - 'bd94ee0b-564f-4088-bf4e-8d5e626caf66' -- col_uuid - )` - query = fmt.Sprintf(query, table) - _, err := conn.Exec(ctx, query) - is.NoErr(err) -} - -func avroTestSchema(t *testing.T, table string) avro.Schema { - is := is.New(t) - - fields := []*avro.Field{ - assert(avro.NewField("col_boolean", avro.NewPrimitiveSchema(avro.Boolean, nil))), - assert(avro.NewField("col_bytea", avro.NewPrimitiveSchema(avro.Bytes, nil))), - assert(avro.NewField("col_varchar", avro.NewPrimitiveSchema(avro.String, nil))), - assert(avro.NewField("col_float4", avro.NewPrimitiveSchema(avro.Float, nil))), - assert(avro.NewField("col_float8", avro.NewPrimitiveSchema(avro.Double, nil))), - assert(avro.NewField("col_int2", avro.NewPrimitiveSchema(avro.Int, nil))), - assert(avro.NewField("col_int4", avro.NewPrimitiveSchema(avro.Int, nil))), - assert(avro.NewField("col_int8", avro.NewPrimitiveSchema(avro.Long, nil))), - assert(avro.NewField("col_text", avro.NewPrimitiveSchema(avro.String, nil))), - assert(avro.NewField("col_numeric", avro.NewPrimitiveSchema( - avro.Bytes, - avro.NewDecimalLogicalSchema(8, 2), - ))), - assert(avro.NewField("col_date", avro.NewPrimitiveSchema( - avro.Int, - avro.NewPrimitiveLogicalSchema(avro.Date), - ))), - assert(avro.NewField("col_timestamp", avro.NewPrimitiveSchema( - avro.Long, - avro.NewPrimitiveLogicalSchema(avro.LocalTimestampMicros), - ))), - assert(avro.NewField("col_timestamptz", avro.NewPrimitiveSchema( - avro.Long, - avro.NewPrimitiveLogicalSchema(avro.TimestampMicros), - ))), - assert(avro.NewField("col_uuid", avro.NewPrimitiveSchema( - avro.String, - avro.NewPrimitiveLogicalSchema(avro.UUID), - ))), - } - - slices.SortFunc(fields, func(a, b *avro.Field) int { - return cmp.Compare(a.Name(), b.Name()) - }) - - s, err := avro.NewRecordSchema(table, "", fields) - is.NoErr(err) - - return s -} - -func avrolizeMap(fields []pgconn.FieldDescription, values []any) map[string]any { - row := make(map[string]any) - - for i, f := range fields { - switch f.DataTypeOID { - case pgtype.NumericOID: - n := new(big.Rat) - n.SetString(fmt.Sprint(types.Format(0, values[i]))) - row[f.Name] = n - case pgtype.UUIDOID: - row[f.Name] = fmt.Sprint(values[i]) - default: - row[f.Name] = values[i] - } - } - - return row -} - -func assert[T any](a T, err error) T { - if err != nil { - panic(err) - } - - return a -} diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go index 2f52787..2a46162 100644 --- a/source/snapshot/fetch_worker.go +++ b/source/snapshot/fetch_worker.go @@ -94,20 +94,23 @@ type FetchWorker struct { db *pgxpool.Pool out chan<- []FetchData - keySchema *cschema.Schema - payloadSchema *cschema.Schema + // tableInfoFetcher provides table metadata, including whether each column is NOT NULL.
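A hedged sketch of how this new field is used end to end. It relies only on names visible in this diff (NewTableInfoFetcher, Refresh, GetTable, Columns[...].IsNotNull); any other shape of internal.TableInfo is an assumption:

```go
package example

import (
	"context"
	"fmt"

	"github.com/conduitio/conduit-connector-postgres/internal"
	"github.com/jackc/pgx/v5/pgxpool"
)

// columnIsNotNull mirrors the FetchWorker flow in this diff: Init refreshes
// the cached table info, and buildRecordData later reads
// Columns[name].IsNotNull from it to pick pointer vs. value formatting.
func columnIsNotNull(ctx context.Context, pool *pgxpool.Pool, table, column string) (bool, error) {
	fetcher := internal.NewTableInfoFetcher(pool)
	if err := fetcher.Refresh(ctx, table); err != nil {
		return false, fmt.Errorf("failed to refresh table info: %w", err)
	}
	// Assumption: GetTable returns table info whose Columns map is keyed by
	// column name, as buildRecordData below uses it.
	return fetcher.GetTable(table).Columns[column].IsNotNull, nil
}
```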
+ tableInfoFetcher *internal.TableInfoFetcher + keySchema *cschema.Schema - snapshotEnd int64 - lastRead int64 - cursorName string + payloadSchema *cschema.Schema + snapshotEnd int64 + lastRead int64 + cursorName string } func NewFetchWorker(db *pgxpool.Pool, out chan<- []FetchData, c FetchConfig) *FetchWorker { f := &FetchWorker{ - conf: c, - db: db, - out: out, - cursorName: "fetcher_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + conf: c, + db: db, + out: out, + tableInfoFetcher: internal.NewTableInfoFetcher(db), + cursorName: "fetcher_" + strings.ReplaceAll(uuid.NewString(), "-", ""), } if f.conf.FetchSize == 0 { @@ -126,9 +129,23 @@ func NewFetchWorker(db *pgxpool.Pool, out chan<- []FetchData, c FetchConfig) *Fe return f } -// Validate will ensure the config is correct. +// Init will ensure the config is correct. // * Table and keys exist // * Key is a primary key +func (f *FetchWorker) Init(ctx context.Context) error { + err := f.Validate(ctx) + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + err = f.tableInfoFetcher.Refresh(ctx, f.conf.Table) + if err != nil { + return fmt.Errorf("failed to refresh table info: %w", err) + } + + return nil +} + func (f *FetchWorker) Validate(ctx context.Context) error { if err := f.conf.Validate(); err != nil { return fmt.Errorf("failed to validate config: %w", err) @@ -384,9 +401,12 @@ func (f *FetchWorker) buildRecordData(fields []pgconn.FieldDescription, values [ payload = make(opencdc.StructuredData) ) + tableInfo := f.getTableInfo() for i, fd := range fields { + isNotNull := tableInfo.Columns[fd.Name].IsNotNull + if fd.Name == f.conf.Key { - k, err := types.Format(fd.DataTypeOID, values[i]) + k, err := types.Format(fd.DataTypeOID, values[i], isNotNull) if err != nil { return key, payload, fmt.Errorf("failed to format key %q: %w", f.conf.Key, err) } @@ -394,7 +414,7 @@ func (f *FetchWorker) buildRecordData(fields []pgconn.FieldDescription, values [ key[f.conf.Key] = k } - v, err := types.Format(fd.DataTypeOID, values[i]) + v, err := types.Format(fd.DataTypeOID, values[i], isNotNull) if err != nil { return key, payload, fmt.Errorf("failed to format payload field %q: %w", fd.Name, err) } @@ -489,7 +509,7 @@ func (f *FetchWorker) extractSchemas(ctx context.Context, fields []pgconn.FieldD sdk.Logger(ctx).Debug(). Msgf("extracting payload schema for %v fields in %v", len(fields), f.conf.Table) - avroPayloadSch, err := schema.Avro.Extract(f.conf.Table+"_payload", fields) + avroPayloadSch, err := schema.Avro.Extract(f.conf.Table+"_payload", f.tableInfoFetcher.GetTable(f.conf.Table), fields) if err != nil { return fmt.Errorf("failed to extract payload schema for table %v: %w", f.conf.Table, err) } @@ -509,7 +529,7 @@ func (f *FetchWorker) extractSchemas(ctx context.Context, fields []pgconn.FieldD sdk.Logger(ctx).Debug(). 
Msgf("extracting schema for key %v in %v", f.conf.Key, f.conf.Table) - avroKeySch, err := schema.Avro.Extract(f.conf.Table+"_key", fields, f.conf.Key) + avroKeySch, err := schema.Avro.Extract(f.conf.Table+"_key", f.getTableInfo(), fields, f.conf.Key) if err != nil { return fmt.Errorf("failed to extract key schema for table %v: %w", f.conf.Table, err) } @@ -527,3 +547,7 @@ func (f *FetchWorker) extractSchemas(ctx context.Context, fields []pgconn.FieldD return nil } + +func (f *FetchWorker) getTableInfo() *internal.TableInfo { + return f.tableInfoFetcher.GetTable(f.conf.Table) +} diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_integration_test.go similarity index 82% rename from source/snapshot/fetch_worker_test.go rename to source/snapshot/fetch_worker_integration_test.go index d558fa7..4824739 100644 --- a/source/snapshot/fetch_worker_test.go +++ b/source/snapshot/fetch_worker_integration_test.go @@ -21,14 +21,12 @@ import ( "math/big" "strings" "testing" - "time" "github.com/conduitio/conduit-commons/opencdc" "github.com/conduitio/conduit-connector-postgres/source/position" "github.com/conduitio/conduit-connector-postgres/test" "github.com/google/go-cmp/cmp" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/matryer/is" "gopkg.in/tomb.v2" @@ -126,76 +124,81 @@ func Test_FetcherValidate(t *testing.T) { ) // uppercase table name is required to test primary key fetching - test.SetupTestTableWithName(ctx, t, pool, table) + test.SetupTableWithName(ctx, t, pool, table) t.Run("success", func(t *testing.T) { is := is.New(t) - f := FetchWorker{ - db: pool, - conf: FetchConfig{ + f := NewFetchWorker( + pool, + make(chan<- []FetchData), + FetchConfig{ Table: table, Key: "id", }, - } + ) - is.NoErr(f.Validate(ctx)) + is.NoErr(f.Init(ctx)) }) t.Run("table missing", func(t *testing.T) { is := is.New(t) - f := FetchWorker{ - db: pool, - conf: FetchConfig{ + f := NewFetchWorker( + pool, + make(chan<- []FetchData), + FetchConfig{ Table: "missing_table", Key: "id", }, - } + ) - err := f.Validate(ctx) + err := f.Init(ctx) is.True(err != nil) is.True(strings.Contains(err.Error(), `table "missing_table" does not exist`)) }) t.Run("key is wrong type", func(t *testing.T) { is := is.New(t) - f := FetchWorker{ - db: pool, - conf: FetchConfig{ + f := NewFetchWorker( + pool, + make(chan<- []FetchData), + FetchConfig{ Table: table, Key: "column3", }, - } + ) - err := f.Validate(ctx) + err := f.Init(ctx) is.True(err != nil) is.True(strings.Contains(err.Error(), `failed to validate key: key "column3" of type "boolean" is unsupported`)) }) t.Run("key is not pk", func(t *testing.T) { is := is.New(t) - f := FetchWorker{ - db: pool, - conf: FetchConfig{ + f := NewFetchWorker( + pool, + make(chan<- []FetchData), + FetchConfig{ Table: table, Key: "column2", }, - } + ) - err := f.Validate(ctx) + err := f.Init(ctx) is.NoErr(err) // no error, only a warning }) t.Run("missing key", func(t *testing.T) { is := is.New(t) - f := FetchWorker{ - db: pool, - conf: FetchConfig{ + f := NewFetchWorker( + pool, + make(chan<- []FetchData), + FetchConfig{ Table: table, Key: "missing_key", }, - } + ) - err := f.Validate(ctx) + err := f.Init(ctx) is.True(err != nil) ok := strings.Contains(err.Error(), fmt.Sprintf(`key "missing_key" not present on table %q`, table)) if !ok { @@ -210,7 +213,7 @@ func Test_FetcherRun_EmptySnapshot(t *testing.T) { is = is.New(t) ctx = test.Context(t) pool = test.ConnectPool(context.Background(), t, test.RegularConnString) - 
table = test.SetupEmptyTestTable(context.Background(), t, pool) + table = test.SetupEmptyTable(context.Background(), t, pool) out = make(chan []FetchData) testTomb = &tomb.Tomb{} ) @@ -242,7 +245,7 @@ func Test_FetcherRun_EmptySnapshot(t *testing.T) { func Test_FetcherRun_Initial(t *testing.T) { var ( pool = test.ConnectPool(context.Background(), t, test.RegularConnString) - table = test.SetupTestTable(context.Background(), t, pool) + table = test.SetupTable(context.Background(), t, pool) is = is.New(t) out = make(chan []FetchData) ctx = test.Context(t) @@ -258,7 +261,7 @@ func Test_FetcherRun_Initial(t *testing.T) { ctx = tt.Context(ctx) defer close(out) - if err := f.Validate(ctx); err != nil { + if err := f.Init(ctx); err != nil { return err } return f.Run(ctx) @@ -272,16 +275,11 @@ func Test_FetcherRun_Initial(t *testing.T) { is.NoErr(tt.Err()) is.Equal(len(gotFetchData), 4) - var ( - value6 = []byte(`{"foo": "bar"}`) - value7 = []byte(`{"foo": "baz"}`) - ) - expectedMatch := []opencdc.StructuredData{ - {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), "column5": big.NewRat(4, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(1)}, - {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": big.NewRat(1342, 100), "column5": big.NewRat(8, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(2)}, - {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": big.NewRat(9, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(3)}, - {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": big.NewRat(911, 10), "column5": nil, "column6": nil, "column7": nil, "UppercaseColumn1": nil}, + {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), "UppercaseColumn1": int32(1)}, + {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": big.NewRat(1342, 100), "UppercaseColumn1": int32(2)}, + {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": big.NewRat(836, 25), "UppercaseColumn1": int32(3)}, + {"id": int64(4), "key": []uint8{52}, "column1": "qux", "column2": int32(444), "column3": false, "column4": big.NewRat(911, 10), "UppercaseColumn1": int32(4)}, } for i, got := range gotFetchData { @@ -291,9 +289,7 @@ func Test_FetcherRun_Initial(t *testing.T) { is.Equal("", cmp.Diff( expectedMatch[i], got.Payload, - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, )) is.Equal(got.Position, position.SnapshotPosition{ @@ -308,7 +304,7 @@ func Test_FetcherRun_Initial(t *testing.T) { func Test_FetcherRun_Resume(t *testing.T) { var ( pool = test.ConnectPool(context.Background(), t, test.RegularConnString) - table = test.SetupTestTable(context.Background(), t, pool) + table = test.SetupTable(context.Background(), t, pool) is = is.New(t) out = make(chan []FetchData) ctx = test.Context(t) @@ -333,7 +329,7 @@ func Test_FetcherRun_Resume(t *testing.T) { ctx = tt.Context(ctx) defer close(out) - if err := f.Validate(ctx); err != nil { + if err := f.Init(ctx); err != nil { return err } return f.Run(ctx) @@ -359,15 +355,10 @@ func Test_FetcherRun_Resume(t *testing.T) { "column1": "baz", "column2": int32(789), "column3": false, - "column4": 
nil, - "column5": big.NewRat(9, 1), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), + "column4": big.NewRat(836, 25), "UppercaseColumn1": int32(3), }, - cmp.Comparer(func(x, y *big.Rat) bool { - return x.Cmp(y) == 0 - }), + test.BigRatComparer, ), ) @@ -460,31 +451,6 @@ func Test_send(t *testing.T) { is.Equal(err, context.Canceled) } -func Test_FetchWorker_buildRecordData(t *testing.T) { - var ( - is = is.New(t) - now = time.Now().UTC() - - // special case fields - fields = []pgconn.FieldDescription{{Name: "id"}, {Name: "time"}} - values = []any{1, now} - expectValues = []any{1, now} - ) - - key, payload, err := (&FetchWorker{ - conf: FetchConfig{Table: "mytable", Key: "id"}, - }).buildRecordData(fields, values) - - is.NoErr(err) - is.Equal(len(payload), 2) - for i, fd := range fields { - is.Equal(payload[fd.Name], expectValues[i]) - } - - is.Equal(len(key), 1) - is.Equal(key["id"], 1) -} - func Test_FetchWorker_updateSnapshotEnd(t *testing.T) { var ( is = is.New(t) @@ -493,7 +459,7 @@ func Test_FetchWorker_updateSnapshotEnd(t *testing.T) { table = strings.ToUpper(test.RandomIdentifier(t)) ) - test.SetupTestTableWithName(ctx, t, pool, table) + test.SetupTableWithName(ctx, t, pool, table) tx, err := pool.Begin(ctx) is.NoErr(err) @@ -519,7 +485,7 @@ func Test_FetchWorker_updateSnapshotEnd(t *testing.T) { Table: table, Key: "UppercaseColumn1", }}, - expected: 3, + expected: 4, }, { desc: "skip update when set", @@ -555,7 +521,7 @@ func Test_FetchWorker_updateSnapshotEnd(t *testing.T) { func Test_FetchWorker_createCursor(t *testing.T) { var ( pool = test.ConnectPool(context.Background(), t, test.RegularConnString) - table = test.SetupTestTable(context.Background(), t, pool) + table = test.SetupTable(context.Background(), t, pool) is = is.New(t) ctx = test.Context(t) ) diff --git a/source/snapshot/iterator.go b/source/snapshot/iterator.go index bafa26f..b1910ef 100644 --- a/source/snapshot/iterator.go +++ b/source/snapshot/iterator.go @@ -179,7 +179,7 @@ func (i *Iterator) initFetchers(ctx context.Context) error { WithAvroSchema: i.conf.WithAvroSchema, }) - if err := w.Validate(ctx); err != nil { + if err := w.Init(ctx); err != nil { errs = append(errs, fmt.Errorf("failed to validate table fetcher %q config: %w", t, err)) } diff --git a/source/snapshot/iterator_test.go b/source/snapshot/iterator_test.go index 8baa792..9b450cf 100644 --- a/source/snapshot/iterator_test.go +++ b/source/snapshot/iterator_test.go @@ -30,7 +30,7 @@ func Test_Iterator_NextN(t *testing.T) { var ( ctx = test.Context(t) pool = test.ConnectPool(ctx, t, test.RegularConnString) - table = test.SetupTestTable(ctx, t, pool) + table = test.SetupTable(ctx, t, pool) ) t.Run("success", func(t *testing.T) { diff --git a/source/types/types.go b/source/types/types.go index 87b1c67..d9de8ba 100644 --- a/source/types/types.go +++ b/source/types/types.go @@ -15,15 +15,44 @@ package types import ( + "reflect" + "github.com/jackc/pgx/v5/pgtype" ) var ( Numeric = NumericFormatter{} - UUID = UUIDFormatter{} + + UUID = UUIDFormatter{} ) -func Format(oid uint32, v any) (any, error) { +// Format formats the input value v with the corresponding Postgres OID +// into an appropriate Go value (that can later be serialized with Avro). +// If the input value is nullable (i.e. isNotNull is false), then the method +// returns a pointer. 
+// +// The following types are currently not supported: +// bit, varbit, box, char(n), cidr, circle, inet, interval, line, lseg, +// macaddr, macaddr8, money, path, pg_lsn, pg_snapshot, point, polygon, +// time, timetz, tsquery, tsvector, xml +func Format(oid uint32, v any, isNotNull bool) (any, error) { + if v == nil { + return nil, nil + } + + val, err := format(oid, v) + if err != nil { + return nil, err + } + + if reflect.TypeOf(val).Kind() != reflect.Ptr && !isNotNull { + return GetPointer(val), nil + } + + return val, nil +} + +func format(oid uint32, v any) (any, error) { if oid == pgtype.UUIDOID { return UUID.Format(v) } @@ -42,3 +71,33 @@ func Format(oid uint32, v any) (any, error) { return t, nil } } + +func GetPointer(v any) any { + rv := reflect.ValueOf(v) + + // If the value is nil or invalid, return nil + if !rv.IsValid() { + return nil + } + + // If it's already a pointer, return it as-is + if rv.Kind() == reflect.Ptr { + return rv.Interface() + } + + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Map || rv.Kind() == reflect.Array { + return rv.Interface() + } + + // For non-pointer values, we need to get the address + // If the value is addressable, return its address + if rv.CanAddr() { + return rv.Addr().Interface() + } + + // If we can't get the address directly, create an addressable copy + // This happens when the interface{} contains a literal value + ptr := reflect.New(rv.Type()) + ptr.Elem().Set(rv) + return ptr.Interface() +} diff --git a/source/types/types_test.go b/source/types/types_test.go index fb1dea7..d7d2520 100644 --- a/source/types/types_test.go +++ b/source/types/types_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/conduitio/conduit-commons/lang" "github.com/jackc/pgx/v5/pgtype" "github.com/matryer/is" ) @@ -27,11 +28,11 @@ func Test_Format(t *testing.T) { now := time.Now().UTC() tests := []struct { - name string - input []any - inputOID []uint32 - expect []any - withBuiltin bool + name string + input []any + inputOID []uint32 + expect []any + expectNullable []any }{ { name: "int float string bool", @@ -44,6 +45,9 @@ func Test_Format(t *testing.T) { expect: []any{ 1021, 199.2, "foo", true, }, + expectNullable: []any{ + lang.Ptr(1021), lang.Ptr(199.2), lang.Ptr("foo"), lang.Ptr(true), + }, }, { name: "pgtype.Numeric", @@ -56,6 +60,9 @@ func Test_Format(t *testing.T) { expect: []any{ big.NewRat(122121, 10000), big.NewRat(101, 1), big.NewRat(0, 1), nil, nil, }, + expectNullable: []any{ + big.NewRat(122121, 10000), big.NewRat(101, 1), big.NewRat(0, 1), nil, nil, + }, }, { name: "builtin time.Time", @@ -68,31 +75,40 @@ func Test_Format(t *testing.T) { expect: []any{ now, }, - withBuiltin: true, + expectNullable: []any{ + lang.Ptr(now), + }, }, { name: "uuid", input: []any{ - [16]uint8{0xbd, 0x94, 0xee, 0x0b, 0x56, 0x4f, 0x40, 0x88, 0xbf, 0x4e, 0x8d, 0x5e, 0x62, 0x6c, 0xaf, 0x66}, nil, + [16]uint8{0xbd, 0x94, 0xee, 0x0b, 0x56, 0x4f, 0x40, 0x88, 0xbf, 0x4e, 0x8d, 0x5e, 0x62, 0x6c, 0xaf, 0x66}, + nil, }, inputOID: []uint32{ pgtype.UUIDOID, pgtype.UUIDOID, }, expect: []any{ - "bd94ee0b-564f-4088-bf4e-8d5e626caf66", "", + "bd94ee0b-564f-4088-bf4e-8d5e626caf66", nil, + }, + expectNullable: []any{ + lang.Ptr("bd94ee0b-564f-4088-bf4e-8d5e626caf66"), nil, }, }, } - _ = time.Now() for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { is := is.New(t) for i, in := range tc.input { - v, err := Format(tc.inputOID[i], in) + v, err := Format(tc.inputOID[i], in, true) is.NoErr(err) is.Equal(v, tc.expect[i]) + + vNullable, err := 
Format(tc.inputOID[i], in, false) + is.NoErr(err) + is.Equal(vNullable, tc.expectNullable[i]) } }) } diff --git a/source/types/uuid.go b/source/types/uuid.go index c2f8093..c941e00 100644 --- a/source/types/uuid.go +++ b/source/types/uuid.go @@ -22,8 +22,8 @@ import ( type UUIDFormatter struct{} -// Format takes a slice of bytes and returns a UUID in string format -// Returns error when byte array cannot be parsed. +// Format transforms a byte array into a UUID in string format. +// Returns an error if the byte array cannot be parsed. func (UUIDFormatter) Format(v any) (string, error) { if v == nil { return "", nil diff --git a/source_integration_test.go b/source_integration_test.go index 527078b..4f98034 100644 --- a/source_integration_test.go +++ b/source_integration_test.go @@ -17,30 +17,161 @@ package postgres import ( "context" "fmt" + "math/big" "strings" "testing" + "time" + "github.com/Masterminds/squirrel" "github.com/conduitio/conduit-commons/config" + "github.com/conduitio/conduit-commons/opencdc" + "github.com/conduitio/conduit-connector-postgres/internal" "github.com/conduitio/conduit-connector-postgres/source" "github.com/conduitio/conduit-connector-postgres/source/logrepl" "github.com/conduitio/conduit-connector-postgres/test" sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/conduitio/conduit-connector-sdk/schema" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/matryer/is" + "github.com/shopspring/decimal" ) -func TestSource_Open(t *testing.T) { - is := is.New(t) - ctx := test.Context(t) - conn := test.ConnectSimple(ctx, t, test.RepmgrConnString) +var ( + slotName = "conduitslot1" + publicationName = "conduitpub1" +) + +func TestSource_ReadN_Snapshot(t *testing.T) { + testCases := []struct { + name string + notNullOnly bool + }{ + { + name: "with null columns", + notNullOnly: false, + }, + { + name: "not null only", + notNullOnly: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + ctx := test.Context(t) + conn := test.ConnectSimple(ctx, t, test.RepmgrConnString) + + tableName := createTableWithManyTypes(ctx, t) + insertRow(ctx, is, conn, tableName, 1, tc.notNullOnly) + + s := openSource(ctx, is, tableName) + t.Cleanup(func() { + is.NoErr(logrepl.Cleanup(context.Background(), logrepl.CleanupConfig{ + URL: test.RepmgrConnString, + SlotName: slotName, + PublicationName: publicationName, + })) + is.NoErr(s.Teardown(ctx)) + }) + + // Read, ack, and assert the snapshot record is OK + rec := readAndAck(ctx, is, s) + assertRecordOK(is, tableName, rec, 1, tc.notNullOnly) + }) + } +} + +func TestSource_ReadN_CDC(t *testing.T) { + testCases := []struct { + name string + notNullOnly bool + }{ + { + name: "with null columns", + notNullOnly: false, + }, + { + name: "not null only", + notNullOnly: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + ctx := test.Context(t) + conn := test.ConnectSimple(ctx, t, test.RepmgrConnString) + + tableName := createTableWithManyTypes(ctx, t) + + s := openSource(ctx, is, tableName) + t.Cleanup(func() { + is.NoErr(logrepl.Cleanup(context.Background(), logrepl.CleanupConfig{ + URL: test.RepmgrConnString, + SlotName: slotName, + PublicationName: publicationName, + })) + is.NoErr(s.Teardown(ctx)) + }) + + insertRow(ctx, is, conn, tableName, 1, tc.notNullOnly) + // Read, ack, and assert the CDC record is OK + rec := readAndAck(ctx, is, s) + assertRecordOK(is, tableName, rec, 1,
tc.notNullOnly) + }) + } +} + +func TestSource_ReadN_Delete(t *testing.T) { + t.Skip("Skipping until this issue is resolved: https://github.com/ConduitIO/conduit-connector-postgres/issues/301") + testCases := []struct { + name string + notNullOnly bool + }{ + { + name: "with null columns", + notNullOnly: false, + }, + { + name: "not only only", + notNullOnly: true, + }, + } - // Be sure primary key discovering works correctly on - // table names with capital letters - tableName := strings.ToUpper(test.RandomIdentifier(t)) - test.SetupTestTableWithName(ctx, t, conn, tableName) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + ctx := test.Context(t) + conn := test.ConnectSimple(ctx, t, test.RepmgrConnString) - slotName := "conduitslot1" - publicationName := "conduitpub1" + tableName := createTableWithManyTypes(ctx, t) + + s := openSource(ctx, is, tableName) + t.Cleanup(func() { + is.NoErr(logrepl.Cleanup(context.Background(), logrepl.CleanupConfig{ + URL: test.RepmgrConnString, + SlotName: slotName, + PublicationName: publicationName, + })) + is.NoErr(s.Teardown(ctx)) + }) + + insertRow(ctx, is, conn, tableName, 1, tc.notNullOnly) + // Read, ack, and assert the CDC record is OK + cdcRec := readAndAck(ctx, is, s) + assertRecordOK(is, tableName, cdcRec, 1, tc.notNullOnly) + + deleteRow(ctx, is, conn, tableName, 1) + deleteRec := readAndAck(ctx, is, s) + is.Equal(opencdc.OperationDelete, deleteRec.Operation) + }) + } +} +func openSource(ctx context.Context, is *is.I, tableName string) sdk.Source { s := NewSource() err := sdk.Util.ParseConfig( ctx, @@ -60,14 +191,274 @@ func TestSource_Open(t *testing.T) { err = s.Open(ctx, nil) is.NoErr(err) - defer func() { - is.NoErr(logrepl.Cleanup(context.Background(), logrepl.CleanupConfig{ - URL: test.RepmgrConnString, - SlotName: slotName, - PublicationName: publicationName, - })) - is.NoErr(s.Teardown(ctx)) - }() + return s +} + +func readAndAck(ctx context.Context, is *is.I, s sdk.Source) opencdc.Record { + recs, err := s.ReadN(ctx, 1) + is.NoErr(err) + is.Equal(1, len(recs)) + + err = s.Ack(ctx, recs[0].Position) + is.NoErr(err) + + return recs[0] +} + +func createTableWithManyTypes(ctx context.Context, t *testing.T) string { + is := is.New(t) + + conn := test.ConnectSimple(ctx, t, test.RepmgrConnString) + // Verify we can discover the primary key even when the table name + // contains capital letters. 
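Why the upper-casing in the next statement matters: Postgres folds unquoted identifiers to lower case, so a mixed- or upper-case table name only round-trips if it is quoted everywhere it is used (hence the %q below and internal.WrapSQLIdent in insertRow/deleteRow). An illustrative sketch, standard library only:

```go
package example

import "fmt"

// quotedCreate builds a CREATE TABLE statement for a case-sensitive name.
// Without the double quotes that %q adds, Postgres would fold ABC123 to
// abc123 and primary-key discovery against the original name would fail.
func quotedCreate(table string) string {
	return fmt.Sprintf(`CREATE TABLE %q (id integer PRIMARY KEY)`, table)
}
```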
+ table := strings.ToUpper(test.RandomIdentifier(t)) + + query := fmt.Sprintf(`CREATE TABLE %q ( + id integer PRIMARY KEY, + col_bytea bytea, + col_bytea_not_null bytea NOT NULL, + col_varchar varchar(30), + col_varchar_not_null varchar(30) NOT NULL, + col_date date, + col_date_not_null date NOT NULL, + col_float4 float4, + col_float4_not_null float4 NOT NULL, + col_float8 float8, + col_float8_not_null float8 NOT NULL, + col_int2 int2, + col_int2_not_null int2 NOT NULL, + col_int4 int4, + col_int4_not_null int4 NOT NULL, + col_int8 int8, + col_int8_not_null int8 NOT NULL, + col_numeric numeric(8,2), + col_numeric_not_null numeric(8,2) NOT NULL, + col_text text, + col_text_not_null text NOT NULL, + col_timestamp timestamp, + col_timestamp_not_null timestamp NOT NULL, + col_timestamptz timestamptz, + col_timestamptz_not_null timestamptz NOT NULL, + col_uuid uuid, + col_uuid_not_null uuid NOT NULL, + col_json json, + col_json_not_null json NOT NULL, + col_jsonb jsonb, + col_jsonb_not_null jsonb NOT NULL, + col_bool bool, + col_bool_not_null bool NOT NULL, + col_serial serial, + col_serial_not_null serial NOT NULL, + col_smallserial smallserial, + col_smallserial_not_null smallserial NOT NULL, + col_bigserial bigserial, + col_bigserial_not_null bigserial NOT NULL +)`, table) + _, err := conn.Exec(ctx, query) + is.NoErr(err) + + t.Cleanup(func() { + query := `DROP TABLE %q` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(context.Background(), query) + is.NoErr(err) + }) + + return table +} + +// insertRow inserts a row using the values provided by generatePayloadData. +// If notNullOnly is true, only NOT NULL columns are inserted. +func insertRow(ctx context.Context, is *is.I, conn *pgx.Conn, table string, rowNumber int, notNullOnly bool) { + rec := generatePayloadData(rowNumber, false) + + var columns []string + var values []interface{} + for key, value := range rec { + // the database generates serial values, so skip them + if strings.Contains(key, "serial") { + continue + } + // check if only NOT NULL columns are needed (names end in _not_null) + // id is an exception + if notNullOnly && !strings.HasSuffix(key, "_not_null") && key != "id" { + continue + } + + columns = append(columns, key) + // col_numeric is a big.Rat, so we convert it to a string + if strings.HasPrefix(key, "col_numeric") { + values = append(values, decimalString(value)) + } else { + values = append(values, value) + } + } + + query, args, err := squirrel.Insert(internal.WrapSQLIdent(table)). + Columns(columns...). + Values(values...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + + is.NoErr(err) + + _, err = conn.Exec(ctx, query, args...) + is.NoErr(err) +} + +func deleteRow(ctx context.Context, is *is.I, conn *pgx.Conn, table string, rowNumber int) { + query, args, err := squirrel.Delete(internal.WrapSQLIdent(table)). + Where(squirrel.Eq{"id": rowNumber}). + PlaceholderFormat(squirrel.Dollar). + ToSql() + is.NoErr(err) + + _, err = conn.Exec(ctx, query, args...)
+ is.NoErr(err) +} + +func generatePayloadData(id int, notNullOnly bool) opencdc.StructuredData { + // Add a time zone offset + rowTS := assert(time.Parse(time.RFC3339, fmt.Sprintf("2022-01-21T17:04:05+%02d:00", id))) + rowTS = rowTS.Add(time.Duration(id) * time.Hour) + + rowUUID := assert(uuid.Parse(fmt.Sprintf("a74a9875-978e-4832-b1b8-6b0f8793a%03d", id))) + idInt64 := int64(id) + numericVal := big.NewRat(int64(100+id), 10) + + rec := opencdc.StructuredData{ + "id": id, + + "col_bytea": []uint8(fmt.Sprintf("col_bytea_%v", id)), + "col_bytea_not_null": []uint8(fmt.Sprintf("col_bytea_not_null_%v", id)), + "col_varchar": fmt.Sprintf("col_varchar_%v", id), + "col_varchar_not_null": fmt.Sprintf("col_varchar_not_null_%v", id), + "col_text": fmt.Sprintf("col_text_%v", id), + "col_text_not_null": fmt.Sprintf("col_text_not_null_%v", id), + + "col_uuid": rowUUID, + "col_uuid_not_null": rowUUID, + + "col_json": []uint8(fmt.Sprintf(`{"key": "json-value-%v"}`, id)), + "col_json_not_null": []uint8(fmt.Sprintf(`{"key": "json-not-value-%v"}`, id)), + "col_jsonb": []uint8(fmt.Sprintf(`{"key": "jsonb-value-%v"}`, id)), + "col_jsonb_not_null": []uint8(fmt.Sprintf(`{"key": "jsonb-not-value-%v"}`, id)), + + "col_float4": float32(id) / 10, + "col_float4_not_null": float32(id) / 10, + "col_float8": float64(id) / 10, + "col_float8_not_null": float64(id) / 10, + "col_int2": id % 32768, + "col_int2_not_null": id % 32768, + "col_int4": id, + "col_int4_not_null": id, + "col_int8": idInt64, + "col_int8_not_null": idInt64, + + "col_numeric": numericVal, + "col_numeric_not_null": numericVal, + + // NB: these values are not used in insert queries, but we assume + // the test rows will always be inserted in order, i.e., + // test row 1, then test row 2, etc. + "col_serial": id, + "col_serial_not_null": id, + "col_smallserial": id, + "col_smallserial_not_null": id, + "col_bigserial": idInt64, + "col_bigserial_not_null": idInt64, + + "col_date": rowTS.Truncate(24 * time.Hour), + "col_date_not_null": rowTS.Truncate(24 * time.Hour), + "col_timestamp": rowTS.UTC(), + "col_timestamp_not_null": rowTS.UTC(), + "col_timestamptz": rowTS, + "col_timestamptz_not_null": rowTS, + + "col_bool": id%2 == 0, + "col_bool_not_null": id%2 == 1, + } + + if notNullOnly { + for key := range rec { + if !strings.HasSuffix(key, "_not_null") && !strings.HasSuffix(key, "serial") && key != "id" { + rec[key] = nil + } + } + } + + return rec +} + +func decimalString(v interface{}) string { + return decimal.NewFromBigRat(v.(*big.Rat), 2).String() +} + +// assertRecordOK asserts that the input record has a schema and that its payload +// is what we expect (based on the ID and what columns are included). 
+func assertRecordOK(is *is.I, tableName string, gotRecord opencdc.Record, id int, notNullOnly bool) { + is.Helper() + + is.True(gotRecord.Key != nil) + is.True(gotRecord.Payload.After != nil) + + assertSchemaPresent(is, tableName, gotRecord) + assertPayloadOK(is, gotRecord, id, notNullOnly) +} + +func assertSchemaPresent(is *is.I, tableName string, gotRecord opencdc.Record) { + payloadSchemaSubject, err := gotRecord.Metadata.GetPayloadSchemaSubject() + is.NoErr(err) + is.Equal(tableName+"_payload", payloadSchemaSubject) + payloadSchemaVersion, err := gotRecord.Metadata.GetPayloadSchemaVersion() + is.NoErr(err) + is.Equal(1, payloadSchemaVersion) + + keySchemaSubject, err := gotRecord.Metadata.GetKeySchemaSubject() + is.NoErr(err) + is.Equal(tableName+"_key", keySchemaSubject) + keySchemaVersion, err := gotRecord.Metadata.GetKeySchemaVersion() + is.NoErr(err) + is.Equal(1, keySchemaVersion) +} + +// assertPayloadOK decodes the record's payload and asserts that it matches +// what we expect (based on the ID and what columns are included). +func assertPayloadOK(is *is.I, record opencdc.Record, rowNum int, notNullOnly bool) { + is.Helper() + + sch, err := schema.Get( + context.Background(), + assert(record.Metadata.GetPayloadSchemaSubject()), + assert(record.Metadata.GetPayloadSchemaVersion()), + ) + is.NoErr(err) + + got := opencdc.StructuredData{} + err = sch.Unmarshal(record.Payload.After.Bytes(), &got) + is.NoErr(err) + + want := expectedData(rowNum, notNullOnly) + + is.Equal("", cmp.Diff(want, got, test.BigRatComparer)) // expected different payload (-want, +got) +} + +// expectedData creates an opencdc.StructuredData with expected keys and values +// based on the ID and the columns (NOT NULL columns only or all columns). +// Its output differs from generatePayloadData's, because certain values are written +// into the test table as one type, but read as another (e.g., we use UUID objects +// when inserting test data, but they are read as strings). +func expectedData(id int, notNullOnly bool) opencdc.StructuredData { + rec := generatePayloadData(id, notNullOnly) + + for key, value := range rec { + // UUIDs are written as uuid.UUID values but read back as strings. + if strings.HasPrefix(key, "col_uuid") && value != nil { + rec[key] = value.(uuid.UUID).String() + } + } + + return rec } func TestSource_ParseConfig(t *testing.T) { @@ -113,126 +504,9 @@ } } -func TestSource_Read(t *testing.T) { - ctx := test.Context(t) - is := is.New(t) - - conn := test.ConnectSimple(ctx, t, test.RegularConnString) - table := setupSourceTable(ctx, t, conn) - insertSourceRow(ctx, t, conn, table) - - s := NewSource() - err := sdk.Util.ParseConfig( - ctx, - map[string]string{ - "url": test.RepmgrConnString, - "tables": table, - "snapshotMode": "initial", - "cdcMode": "logrepl", - }, - s.Config(), - Connector.NewSpecification().SourceParams, - ) - is.NoErr(err) - - err = s.Open(ctx, nil) - is.NoErr(err) - - recs, err := s.ReadN(ctx, 1) - is.NoErr(err) - - fmt.Println(recs) -} - -// setupSourceTable creates a new table with all types and returns its name.
-func setupSourceTable(ctx context.Context, t *testing.T, conn test.Querier) string { - is := is.New(t) - table := test.RandomIdentifier(t) - // todo still need to support: - // bit, varbit, box, char(n), cidr, circle, inet, interval, line, lseg, - // macaddr, macaddr8, money, path, pg_lsn, pg_snapshot, point, polygon, - // time, timetz, tsquery, tsvector, xml - query := ` - CREATE TABLE %s ( - id bigserial PRIMARY KEY, - col_boolean boolean, - col_bytea bytea, - col_varchar varchar(10), - col_date date, - col_float4 float4, - col_float8 float8, - col_int2 int2, - col_int4 int4, - col_int8 int8, - col_json json, - col_jsonb jsonb, - col_numeric numeric(8,2), - col_serial2 serial2, - col_serial4 serial4, - col_serial8 serial8, - col_text text, - col_timestamp timestamp, - col_timestamptz timestamptz, - col_uuid uuid - )` - query = fmt.Sprintf(query, table) - _, err := conn.Exec(ctx, query) - is.NoErr(err) - - t.Cleanup(func() { - query := `DROP TABLE %s` - query = fmt.Sprintf(query, table) - _, err := conn.Exec(context.Background(), query) - is.NoErr(err) - }) - return table -} - -func insertSourceRow(ctx context.Context, t *testing.T, conn test.Querier, table string) { - is := is.New(t) - query := ` - INSERT INTO %s ( - col_boolean, - col_bytea, - col_varchar, - col_date, - col_float4, - col_float8, - col_int2, - col_int4, - col_int8, - col_json, - col_jsonb, - col_numeric, - col_serial2, - col_serial4, - col_serial8, - col_text, - col_timestamp, - col_timestamptz, - col_uuid - ) VALUES ( - true, -- col_boolean - '\x07', -- col_bytea - '9', -- col_varchar - '2022-03-14', -- col_date - 15, -- col_float4 - 16.16, -- col_float8 - 32767, -- col_int2 - 2147483647, -- col_int4 - 9223372036854775807, -- col_int8 - '{"foo": "bar"}', -- col_json - '{"foo": "baz"}', -- col_jsonb - '292929.29', -- col_numeric - 32767, -- col_serial2 - 2147483647, -- col_serial4 - 9223372036854775807, -- col_serial8 - 'foo bar baz', -- col_text - '2022-03-14 15:16:17', -- col_timestamp - '2022-03-14 15:16:17-08', -- col_timestamptz - 'bd94ee0b-564f-4088-bf4e-8d5e626caf66' -- col_uuid - )` - query = fmt.Sprintf(query, table) - _, err := conn.Exec(ctx, query) - is.NoErr(err) +func assert[T any](val T, err error) T { + if err != nil { + panic(err) + } + return val } diff --git a/test/helper.go b/test/helper.go index 6e72e21..37d240f 100644 --- a/test/helper.go +++ b/test/helper.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "math/big" "strconv" "strings" "testing" @@ -25,6 +26,7 @@ import ( "github.com/conduitio/conduit-commons/csync" "github.com/conduitio/conduit-connector-postgres/source/cpool" + "github.com/google/go-cmp/cmp" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" @@ -44,9 +46,8 @@ const RegularConnString = "postgres://meroxauser:meroxapass@127.0.0.1:5433/merox const TestTableAvroSchemaV1 = `{ "type": "record", "name": "%s", - "fields": - [ - {"name":"UppercaseColumn1","type":"int"}, + "fields": [ + {"name":"UppercaseColumn1","type":"int"}, {"name":"column1","type":"string"}, {"name":"column2","type":"int"}, {"name":"column3","type":"boolean"}, @@ -60,17 +61,6 @@ const TestTableAvroSchemaV1 = `{ "scale": 3 } }, - { - "name": "column5", - "type": - { - "type": "bytes", - "logicalType": "decimal", - "precision": 5 - } - }, - {"name":"column6","type":"bytes"}, - {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -80,11 +70,10 @@ const TestTableAvroSchemaV1 = `{ const TestTableAvroSchemaV2 = `{ "type": "record", 
"name": "%s", - "fields": - [ - {"name":"UppercaseColumn1","type":"int"}, + "fields": [ + {"name":"UppercaseColumn1","type":"int"}, {"name":"column1","type":"string"}, - {"name":"column101","type":{"type":"long","logicalType":"local-timestamp-micros"}}, + {"name":"column101","type":["null", {"type":"long","logicalType":"local-timestamp-micros"}]}, {"name":"column2","type":"int"}, {"name":"column3","type":"boolean"}, { @@ -97,17 +86,6 @@ const TestTableAvroSchemaV2 = `{ "scale": 3 } }, - { - "name": "column5", - "type": - { - "type": "bytes", - "logicalType": "decimal", - "precision": 5 - } - }, - {"name":"column6","type":"bytes"}, - {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -117,15 +95,12 @@ const TestTableAvroSchemaV2 = `{ const TestTableAvroSchemaV3 = `{ "type": "record", "name": "%s", - "fields": - [ - {"name":"UppercaseColumn1","type":"int"}, + "fields": [ + {"name":"UppercaseColumn1","type":"int"}, {"name":"column1","type":"string"}, - {"name":"column101","type":{"type":"long","logicalType":"local-timestamp-micros"}}, + {"name":"column101","type":["null", {"type":"long","logicalType":"local-timestamp-micros"}]}, {"name":"column2","type":"int"}, {"name":"column3","type":"boolean"}, - {"name":"column6","type":"bytes"}, - {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -145,17 +120,18 @@ const TestTableKeyAvroSchema = `{ const testTableCreateQuery = ` CREATE TABLE %q ( id bigserial PRIMARY KEY, - key bytea, - column1 varchar(256), - column2 integer, - column3 boolean, - column4 numeric(16,3), - column5 numeric(5), - column6 jsonb, - column7 json, - "UppercaseColumn1" integer + key bytea not null, + column1 varchar(256) not null, + column2 integer not null, + column3 boolean not null, + column4 numeric(16,3) not null, + "UppercaseColumn1" integer not null )` +var BigRatComparer = cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 +}) + type Querier interface { Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) @@ -185,14 +161,14 @@ func ConnectSimple(ctx context.Context, t *testing.T, connString string) *pgx.Co return conn.Conn() } -// SetupTestTable creates a new table and returns its name. -func SetupEmptyTestTable(ctx context.Context, t *testing.T, conn Querier) string { +// SetupEmptyTable creates an empty test table and returns its name. +func SetupEmptyTable(ctx context.Context, t *testing.T, conn Querier) string { table := RandomIdentifier(t) - SetupEmptyTestTableWithName(ctx, t, conn, table) + SetupEmptyTableWithName(ctx, t, conn, table) return table } -func SetupEmptyTestTableWithName(ctx context.Context, t *testing.T, conn Querier, table string) { +func SetupEmptyTableWithName(ctx context.Context, t *testing.T, conn Querier, table string) { is := is.New(t) query := fmt.Sprintf(testTableCreateQuery, table) @@ -207,25 +183,26 @@ func SetupEmptyTestTableWithName(ctx context.Context, t *testing.T, conn Querier }) } -func SetupTestTableWithName(ctx context.Context, t *testing.T, conn Querier, table string) { +// SetupTableWithName creates a test table with a few row inserted into it. 
+// SetupTableWithName creates a test table with a few rows inserted into it. +func SetupTableWithName(ctx context.Context, t *testing.T, conn Querier, table string) { is := is.New(t) - SetupEmptyTestTableWithName(ctx, t, conn, table) + SetupEmptyTableWithName(ctx, t, conn, table) query := ` - INSERT INTO %q (key, column1, column2, column3, column4, column5, column6, column7, "UppercaseColumn1") - VALUES ('1', 'foo', 123, false, 12.2, 4, '{"foo": "bar"}', '{"foo": "baz"}', 1), - ('2', 'bar', 456, true, 13.42, 8, '{"foo": "bar"}', '{"foo": "baz"}', 2), - ('3', 'baz', 789, false, null, 9, '{"foo": "bar"}', '{"foo": "baz"}', 3), - ('4', null, null, null, 91.1, null, null, null, null)` + INSERT INTO %q (key, column1, column2, column3, column4, "UppercaseColumn1") + VALUES ('1', 'foo', 123, false, 12.2, 1), + ('2', 'bar', 456, true, 13.42, 2), + ('3', 'baz', 789, false, 33.44, 3), + ('4', 'qux', 444, false, 91.1, 4)` query = fmt.Sprintf(query, table) _, err := conn.Exec(ctx, query) is.NoErr(err) } -// SetupTestTable creates a new table and returns its name. -func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { +// SetupTable creates a new table and returns its name. +func SetupTable(ctx context.Context, t *testing.T, conn Querier) string { table := RandomIdentifier(t) - SetupTestTableWithName(ctx, t, conn, table) + SetupTableWithName(ctx, t, conn, table) return table }
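Finally, a sketch of the types.Format contract this whole diff revolves around (OID and input values are illustrative): NOT NULL columns come back as plain values, nullable ones as pointers, nil short-circuits to nil, and, per GetPointer, slices, maps, and arrays pass through unwrapped.

```go
package example

import (
	"github.com/conduitio/conduit-connector-postgres/source/types"
	"github.com/jackc/pgx/v5/pgtype"
)

// formatBoth shows the effect of the new isNotNull flag on types.Format.
func formatBoth() (any, any, error) {
	// NOT NULL column: the formatted value is returned as-is -> int32(7).
	v, err := types.Format(pgtype.Int4OID, int32(7), true)
	if err != nil {
		return nil, nil, err
	}

	// Nullable column: the value is wrapped in a pointer -> *int32.
	p, err := types.Format(pgtype.Int4OID, int32(7), false)
	if err != nil {
		return nil, nil, err
	}

	return v, p, nil
}
```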