diff --git a/internal/config/config.go b/internal/config/config.go
index d5ede1c..fd1253f 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -32,6 +32,7 @@ type Config struct {
 	PollInterval           time.Duration
 	ClientTimeout          time.Duration // max time the http request can last
 	PingTimeout            time.Duration // max time allowed for ping
+	HeartbeatInterval      time.Duration
 	CanUseMultipleCatalogs bool
 	DriverName             string
 	DriverVersion          string
@@ -70,6 +71,7 @@ func (c *Config) DeepCopy() *Config {
 		PollInterval:           c.PollInterval,
 		ClientTimeout:          c.ClientTimeout,
 		PingTimeout:            c.PingTimeout,
+		HeartbeatInterval:      c.HeartbeatInterval,
 		CanUseMultipleCatalogs: c.CanUseMultipleCatalogs,
 		DriverName:             c.DriverName,
 		DriverVersion:          c.DriverVersion,
@@ -189,6 +191,7 @@ func WithDefaults() *Config {
 		PollInterval:           1 * time.Second,
 		ClientTimeout:          900 * time.Second,
 		PingTimeout:            60 * time.Second,
+		HeartbeatInterval:      30 * time.Second,
 		CanUseMultipleCatalogs: true,
 		DriverName:             "godatabrickssqlconnector", // important. Do not change
 		ThriftProtocol:         "binary",
diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go
index 8430ff0..684a768 100644
--- a/internal/fetcher/fetcher.go
+++ b/internal/fetcher/fetcher.go
@@ -2,6 +2,7 @@ package fetcher
 
 import (
 	"context"
+	"errors"
 	"sync"
 
 	"github.com/databricks/databricks-sql-go/driverctx"
@@ -17,6 +18,13 @@ type Fetcher[OutputType any] interface {
 	Start() (<-chan OutputType, context.CancelFunc, error)
 }
 
+// Overwatch is an item that is started and stopped in sync with
+// a fetcher.
+type Overwatch interface {
+	Start()
+	Stop()
+}
+
 type concurrentFetcher[I FetchableItems[O], O any] struct {
 	cancelChan chan bool
 	inputChan  <-chan FetchableItems[O]
@@ -28,6 +36,7 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
 	ctx        context.Context
 	cancelFunc context.CancelFunc
 	*dbsqllog.DBSQLLogger
+	overWatch Overwatch
 }
 
 func (rf *concurrentFetcher[I, O]) Err() error {
@@ -40,6 +49,9 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error)
 	f.start.Do(func() {
 		// wait group for the worker routines
 		var wg sync.WaitGroup
+		if f.overWatch != nil {
+			f.overWatch.Start()
+		}
 
 		for i := 0; i < f.nWorkers; i++ {
 
@@ -64,6 +76,9 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error)
 			wg.Wait()
 			f.logger().Trace().Msg("concurrent fetcher closing output channel")
 			close(f.outChan)
+			if f.overWatch != nil {
+				f.overWatch.Stop()
+			}
 		}()
 
 		// We return a cancel function so that the client can
@@ -98,7 +113,7 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
 	return f.DBSQLLogger
 }
 
-func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
+func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O], overWatch Overwatch) (Fetcher[O], error) {
 	if nWorkers < 1 {
 		nWorkers = 1
 	}
@@ -123,6 +138,7 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
 		cancelChan: stopChannel,
 		ctx:        ctx,
 		nWorkers:   nWorkers,
+		overWatch:  overWatch,
 	}
 
 	return fetcher, nil
@@ -133,10 +149,12 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
 	for {
 		select {
 		case <-f.cancelChan:
+			f.setErr(errors.New("fetcher canceled"))
 			f.logger().Debug().Msgf("concurrent fetcher worker %d received cancel signal", workerIndex)
 			return
 
 		case <-f.ctx.Done():
+			f.setErr(f.ctx.Err())
 			f.logger().Debug().Msgf("concurrent fetcher worker %d context done", workerIndex)
 			return
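The contract of the new hook is worth stating once: Overwatch.Start is invoked a single time inside fetcher.Start, before any workers run, and Overwatch.Stop is invoked after the workers finish and the output channel is closed; passing nil disables the hook. Below is a minimal caller-side sketch under those assumptions (the intItem and printOverwatch types and the demo function are illustrative names, not part of the driver):

    package fetcher

    import (
    	"context"
    	"fmt"
    )

    // intItem is a stand-in FetchableItems implementation used only for this sketch.
    type intItem int

    func (i intItem) Fetch(ctx context.Context) (int, error) { return int(i) * 2, nil }

    // printOverwatch reports when the fetch pipeline starts and finishes.
    type printOverwatch struct{}

    func (p *printOverwatch) Start() { fmt.Println("fetch started") }
    func (p *printOverwatch) Stop()  { fmt.Println("fetch finished") }

    func overwatchDemo() error {
    	in := make(chan FetchableItems[int], 3)
    	for i := 1; i <= 3; i++ {
    		in <- intItem(i)
    	}
    	close(in)

    	// The overwatch rides along with the fetcher: Start fires inside f.Start,
    	// Stop fires once the workers are done and out has been closed.
    	f, err := NewConcurrentFetcher[intItem](context.Background(), 2, 2, in, &printOverwatch{})
    	if err != nil {
    		return err
    	}
    	out, _, err := f.Start()
    	if err != nil {
    		return err
    	}
    	for doubled := range out {
    		fmt.Println(doubled)
    	}
    	return f.Err()
    }
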
diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go
index dbe6ced..8087c90 100644
--- a/internal/fetcher/fetcher_test.go
+++ b/internal/fetcher/fetcher_test.go
@@ -6,7 +6,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/pkg/errors"
+	"github.com/stretchr/testify/assert"
 )
 
 // Create a mock struct for FetchableItems
@@ -32,6 +32,13 @@ func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {
 
 var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil)
 
+type testOverWatch struct {
+	started, stopped bool
+}
+
+func (ow *testOverWatch) Start() { ow.started = true }
+func (ow *testOverWatch) Stop()  { ow.stopped = true }
+
 func TestConcurrentFetcher(t *testing.T) {
 	t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
 		ctx := context.Background()
@@ -43,8 +50,10 @@ func TestConcurrentFetcher(t *testing.T) {
 		}
 		close(inputChan)
 
+		ow := &testOverWatch{}
+
 		// Create a fetcher
-		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
+		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan, ow)
 		if err != nil {
 			t.Fatalf("Error creating fetcher: %v", err)
 		}
@@ -60,6 +69,9 @@ func TestConcurrentFetcher(t *testing.T) {
 			results = append(results, result...)
 		}
 
+		assert.True(t, ow.started)
+		assert.True(t, ow.stopped)
+
 		// Check if the fetcher returned the expected results
 		expectedLen := 50
 		if len(results) != expectedLen {
@@ -83,19 +95,20 @@ func TestConcurrentFetcher(t *testing.T) {
 
 	t.Run("Cancel the concurrent fetcher", func(t *testing.T) {
 		// Create a context with a timeout
-		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 		defer cancel()
 
 		// Create an input channel
-		inputChan := make(chan FetchableItems[[]*mockOutput], 3)
-		for i := 0; i < 3; i++ {
-			item := mockFetchableItem{item: i, wait: 1 * time.Second}
+		inputChan := make(chan FetchableItems[[]*mockOutput], 5)
+		for i := 0; i < 5; i++ {
+			item := mockFetchableItem{item: i, wait: 2 * time.Second}
 			inputChan <- &item
 		}
 		close(inputChan)
+		ow := &testOverWatch{}
 
 		// Create a new fetcher
-		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan)
+		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow)
 		if err != nil {
 			t.Fatalf("Error creating fetcher: %v", err)
 		}
@@ -111,13 +124,106 @@ func TestConcurrentFetcher(t *testing.T) {
 			cancelFunc()
 		}()
 
+		var count int
 		for range outChan {
 			// Just drain the channel
+			count += 1
+		}
+
+		assert.Less(t, count, 5)
+
+		err = fetcher.Err()
+		assert.EqualError(t, err, "fetcher canceled")
+
+		assert.True(t, ow.started)
+		assert.True(t, ow.stopped)
+	})
+
+	t.Run("timeout the concurrent fetcher", func(t *testing.T) {
+		// Create a context with a timeout
+		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		defer cancel()
+
+		// Create an input channel
+		inputChan := make(chan FetchableItems[[]*mockOutput], 10)
+		for i := 0; i < 10; i++ {
+			item := mockFetchableItem{item: i, wait: 1 * time.Second}
+			inputChan <- &item
+		}
+		close(inputChan)
+
+		ow := &testOverWatch{}
+
+		// Create a new fetcher
+		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow)
+		if err != nil {
+			t.Fatalf("Error creating fetcher: %v", err)
 		}
-		// Check if an error occurred
-		if err := fetcher.Err(); err != nil && !errors.Is(err, context.DeadlineExceeded) {
-			t.Fatalf("unexpected error: %v", err)
+		//
Start the fetcher + outChan, _, err := fetcher.Start() + if err != nil { + t.Fatal(err) } + + var count int + for range outChan { + // Just drain the channel + count += 1 + } + + assert.Less(t, count, 10) + + err = fetcher.Err() + assert.EqualError(t, err, "context deadline exceeded") + + assert.True(t, ow.started) + assert.True(t, ow.stopped) + }) + + t.Run("context cancel the concurrent fetcher", func(t *testing.T) { + // Create a context with a timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + // Create an input channel + inputChan := make(chan FetchableItems[[]*mockOutput], 5) + for i := 0; i < 5; i++ { + item := mockFetchableItem{item: i, wait: 2 * time.Second} + inputChan <- &item + } + close(inputChan) + + ow := &testOverWatch{} + + // Create a new fetcher + fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan, ow) + if err != nil { + t.Fatalf("Error creating fetcher: %v", err) + } + + // Start the fetcher + outChan, _, err := fetcher.Start() + if err != nil { + t.Fatal(err) + } + + // Ensure that the fetcher is cancelled successfully + go func() { + cancel() + }() + + var count int + for range outChan { + // Just drain the channel + count += 1 + } + + assert.Less(t, count, 5) + + err = fetcher.Err() + assert.EqualError(t, err, "context canceled") + + assert.True(t, ow.started) + assert.True(t, ow.stopped) }) } diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 2b634ac..4dafb5d 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -2,6 +2,7 @@ package arrowbased import ( "context" + "database/sql/driver" "io" "github.com/apache/arrow/go/v12/arrow" @@ -9,16 +10,27 @@ import ( "github.com/databricks/databricks-sql-go/internal/config" dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" + "github.com/databricks/databricks-sql-go/logger" "github.com/databricks/databricks-sql-go/rows" ) -func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator { +func NewArrowRecordIterator( + ctx context.Context, + rpi rowscanner.ResultPageIterator, + bi BatchIterator, + arrowSchemaBytes []byte, + cfg config.Config, + pinger driver.Pinger, + logger *logger.DBSQLLogger) rows.ArrowBatchIterator { + ari := arrowRecordIterator{ cfg: cfg, batchIterator: bi, resultPageIterator: rpi, ctx: ctx, arrowSchemaBytes: arrowSchemaBytes, + pinger: pinger, + logger: logger, } return &ari @@ -34,6 +46,8 @@ type arrowRecordIterator struct { currentBatch SparkArrowBatch isFinished bool arrowSchemaBytes []byte + pinger driver.Pinger + logger *logger.DBSQLLogger } var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil) @@ -175,7 +189,7 @@ func (ri *arrowRecordIterator) newBatchLoader(fr *cli_service.TFetchResultsResp) var bl BatchLoader var err error if len(rowSet.ResultLinks) > 0 { - bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg, ri.pinger, ri.logger) } else { bl, err = NewLocalBatchLoader(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } diff --git a/internal/rows/arrowbased/arrowRecordIterator_test.go b/internal/rows/arrowbased/arrowRecordIterator_test.go index 
a3e4040..a7e3d7f 100644 --- a/internal/rows/arrowbased/arrowRecordIterator_test.go +++ b/internal/rows/arrowbased/arrowRecordIterator_test.go @@ -37,6 +37,7 @@ func TestArrowRecordIterator(t *testing.T) { 5000, nil, false, + true, client, "connectionId", "correlationId", @@ -55,7 +56,7 @@ func TestArrowRecordIterator(t *testing.T) { assert.Nil(t, err) cfg := *config.WithDefaults() - rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg) + rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg, nil, nil) defer rs.Close() hasNext := rs.HasNext() @@ -127,13 +128,14 @@ func TestArrowRecordIterator(t *testing.T) { 5000, nil, false, + true, client, "connectionId", "correlationId", logger) cfg := *config.WithDefaults() - rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg) + rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg, nil, nil) defer rs.Close() hasNext := rs.HasNext() diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 164d9ff..209faaf 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -75,13 +75,21 @@ type arrowRowScanner struct { ctx context.Context batchIterator BatchIterator + + pinger driver.Pinger } // Make sure arrowRowScanner fulfills the RowScanner interface var _ rowscanner.RowScanner = (*arrowRowScanner)(nil) // NewArrowRowScanner returns an instance of RowScanner which handles arrow format results -func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp, rowSet *cli_service.TRowSet, cfg *config.Config, logger *dbsqllog.DBSQLLogger, ctx context.Context) (rowscanner.RowScanner, dbsqlerr.DBError) { +func NewArrowRowScanner( + resultSetMetadata *cli_service.TGetResultSetMetadataResp, + rowSet *cli_service.TRowSet, + cfg *config.Config, + logger *dbsqllog.DBSQLLogger, + ctx context.Context, + pinger driver.Pinger) (rowscanner.RowScanner, dbsqlerr.DBError) { // we take a passed in logger, rather than just using the global from dbsqllog, so that the containing rows // instance can pass in a logger with context such as correlation ID and operation ID @@ -115,7 +123,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp var bl BatchLoader var err2 dbsqlerr.DBError if len(rowSet.ResultLinks) > 0 { - bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) + bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg, pinger, logger) } else { bl, err2 = NewLocalBatchLoader(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } @@ -146,6 +154,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp DBSQLLogger: logger, location: location, batchIterator: bi, + pinger: pinger, } return rs, nil @@ -326,7 +335,7 @@ func (ars *arrowRowScanner) validateRowNumber(rowNumber int64) dbsqlerr.DBError } func (ars *arrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error) { - ri := NewArrowRecordIterator(ctx, rpi, ars.batchIterator, ars.arrowSchemaBytes, cfg) + ri := NewArrowRecordIterator(ctx, rpi, ars.batchIterator, ars.arrowSchemaBytes, cfg, ars.pinger, ars.DBSQLLogger) return ri, nil } diff --git a/internal/rows/arrowbased/arrowRows_test.go 
b/internal/rows/arrowbased/arrowRows_test.go index c693ff5..7eae7af 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -212,19 +212,19 @@ func TestArrowRowScanner(t *testing.T) { schema := &cli_service.TTableSchema{} metadataResp := getMetadataResp(schema) - ars, err := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, err := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Nil(t, err) assert.Equal(t, int64(0), ars.NRows()) rowSet.ArrowBatches = []*cli_service.TSparkArrowBatch{} - ars, err = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, err = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Nil(t, err) assert.Equal(t, int64(0), ars.NRows()) rowSet.ArrowBatches = []*cli_service.TSparkArrowBatch{{RowCount: 2}, {RowCount: 3}} - ars, _ = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, _ = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Equal(t, int64(5), ars.NRows()) }) @@ -236,7 +236,7 @@ func TestArrowRowScanner(t *testing.T) { schema := getAllTypesSchema() metadataResp := getMetadataResp(schema) - d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -312,7 +312,7 @@ func TestArrowRowScanner(t *testing.T) { cfg.UseArrowNativeTimestamp = true cfg.UseArrowNativeDecimal = true - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -343,14 +343,14 @@ func TestArrowRowScanner(t *testing.T) { cfg.UseArrowNativeTimestamp = true cfg.UseArrowNativeDecimal = true - _, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err) // missing type qualifiers schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -358,7 +358,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -366,7 +366,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers = map[string]*cli_service.TTypeQualifierValue{} metadataResp.Schema = schema - _, err = 
NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -374,7 +374,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["precision"] = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -382,7 +382,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["precision"].I32Value = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -390,7 +390,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["scale"] = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -398,7 +398,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["scale"].I32Value = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) msg := err.Error() pre := "databricks: driver error: " + errArrowRowsConvertSchema + ": " + errArrowRowsInvalidDecimalType @@ -413,7 +413,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseArrowBatches = true - d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err1) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -443,7 +443,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseArrowBatches = true - d, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err) d.Close() @@ -483,7 +483,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -553,7 
+553,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -592,7 +592,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -632,7 +632,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -675,7 +675,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -712,7 +712,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) var ars *arrowRowScanner = d.(*arrowRowScanner) @@ -862,7 +862,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) var ars *arrowRowScanner = d.(*arrowRowScanner) ars.UseArrowNativeComplexTypes = true @@ -940,7 +940,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -968,7 +968,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -991,7 +991,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) 
@@ -1023,7 +1023,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1043,7 +1043,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1119,7 +1119,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1147,7 +1147,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1199,7 +1199,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err1 := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err1 := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err1) ars := d.(*arrowRowScanner) @@ -1240,7 +1240,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1299,7 +1299,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = 
false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1332,7 +1332,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1371,7 +1371,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1410,7 +1410,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1446,7 +1446,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1525,7 +1525,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - _, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + _, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) }) @@ -1548,6 +1548,7 @@ func TestArrowRowScanner(t *testing.T) { 5000, nil, false, + true, client, "connectionId", "correlationId", @@ -1561,6 +1562,7 @@ func TestArrowRowScanner(t *testing.T) { cfg, logger, 
context.Background(), + nil, ) assert.Nil(t, err) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 0c6d3be..5c0f206 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -3,11 +3,13 @@ package arrowbased import ( "bytes" "context" + "database/sql/driver" "io" "time" "github.com/databricks/databricks-sql-go/internal/config" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" + "github.com/databricks/databricks-sql-go/logger" "github.com/pierrec/lz4/v4" "github.com/pkg/errors" @@ -41,7 +43,13 @@ func NewBatchIterator(batchLoader BatchLoader) (BatchIterator, dbsqlerr.DBError) return bi, nil } -func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +func NewCloudBatchLoader( + ctx context.Context, + files []*cli_service.TSparkArrowResultLink, + startRowOffset int64, + cfg *config.Config, + pinger driver.Pinger, + logger *logger.DBSQLLogger) (*batchLoader[*cloudURL], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() @@ -68,7 +76,14 @@ func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe // make sure to close input channel or fetcher will block waiting for more inputs close(inputChan) - f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) + // If a Pinger was passed in create a heartbeat to keep the connection alive while cloud + // files are being downloaded. + var hb *heartBeat + if pinger != nil { + hb = &heartBeat{pinger: pinger, interval: cfg.HeartbeatInterval, logger: logger} + } + + f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan, hb) cbl := &batchLoader[*cloudURL]{ Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), fetcher: f, @@ -103,7 +118,7 @@ func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrow } close(inputChan) - f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) + f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan, nil) cbl := &batchLoader[*localBatch]{ Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), fetcher: f, @@ -120,6 +135,8 @@ type batchLoader[T interface { fetcher fetcher.Fetcher[SparkArrowBatch] arrowBatches []SparkArrowBatch ctx context.Context + cancelFetch context.CancelFunc + batchChan <-chan SparkArrowBatch } var _ BatchLoader = (*batchLoader[*localBatch])(nil) @@ -132,7 +149,10 @@ func (cbl *batchLoader[T]) GetBatchFor(rowNumber int64) (SparkArrowBatch, dbsqle } } - batchChan, _, err := cbl.fetcher.Start() + batchChan, cancelFetch, err := cbl.fetcher.Start() + cbl.batchChan = batchChan + cbl.cancelFetch = cancelFetch + var emptyBatch SparkArrowBatch if err != nil { return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) @@ -158,9 +178,20 @@ func (cbl *batchLoader[T]) GetBatchFor(rowNumber int64) (SparkArrowBatch, dbsqle } func (cbl *batchLoader[T]) Close() { + if cbl.cancelFetch != nil { + cbl.cancelFetch() + } + for i := range cbl.arrowBatches { cbl.arrowBatches[i].Close() } + + // drain any batches in the fetcher output channel + if cbl.batchChan != nil { + for b := range cbl.batchChan { + b.Close() + } + } } type compressibleBatch struct { @@ -318,3 +349,59 @@ 
func (bi *batchIterator) Close() {
 		bi.batchLoader.Close()
 	}
 }
+
+// Once started, heartBeat calls Ping on its Pinger at
+// a regular interval until it is stopped.
+type heartBeat struct {
+	pinger    driver.Pinger
+	stopChan  chan bool
+	interval  time.Duration
+	running   bool
+	logger    *logger.DBSQLLogger
+	err       error
+	beatCount int
+}
+
+var _ fetcher.Overwatch = (*heartBeat)(nil)
+
+func (hb *heartBeat) Start() {
+	hb.logger.Debug().Msg("heartbeat: starting")
+	hb.running = true
+	if hb.stopChan == nil {
+		hb.stopChan = make(chan bool)
+	}
+
+	go func() {
+		it := time.NewTimer(hb.interval)
+		defer it.Stop()
+
+		for {
+			select {
+			case <-it.C:
+				hb.beatCount += 1
+				err := hb.pinger.Ping(context.Background())
+				if err != nil {
+					hb.logger.Debug().Msg("heartbeat: ping failed")
+					hb.running = false
+					hb.err = err
+					return
+				}
+				hb.logger.Debug().Msg("heartbeat: ping success")
+				it.Reset(hb.interval)
+
+			case <-hb.stopChan:
+				hb.running = false
+				hb.logger.Debug().Msg("heartbeat: stopping")
+				return
+			}
+		}
+	}()
+}
+
+func (hb *heartBeat) Stop() {
+	if hb.stopChan == nil {
+		hb.stopChan = make(chan bool)
+	}
+
+	close(hb.stopChan)
+}
diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go
index 35bad33..6859e90 100644
--- a/internal/rows/arrowbased/batchloader_test.go
+++ b/internal/rows/arrowbased/batchloader_test.go
@@ -15,8 +15,11 @@ import (
 	"github.com/apache/arrow/go/v12/arrow/ipc"
 	"github.com/apache/arrow/go/v12/arrow/memory"
 	dbsqlerr "github.com/databricks/databricks-sql-go/errors"
+	"github.com/databricks/databricks-sql-go/internal/cli_service"
+	"github.com/databricks/databricks-sql-go/internal/config"
 	dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
 	"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
+	"github.com/databricks/databricks-sql-go/logger"
 	"github.com/pkg/errors"
 	"github.com/stretchr/testify/assert"
 )
@@ -127,6 +130,202 @@ func TestCloudURLFetch(t *testing.T) {
 	}
 }
 
+func TestBatchLoader(t *testing.T) {
+
+	t.Run("test loading", func(t *testing.T) {
+		var nLoads int
+		var handler func(w http.ResponseWriter, r *http.Request) = func(w http.ResponseWriter, r *http.Request) {
+			nLoads += 1
+			w.WriteHeader(http.StatusOK)
+			_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
+			if err != nil {
+				panic(err)
+			}
+
+		}
+
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			handler(w, r)
+		}))
+		defer server.Close()
+
+		expiryTime := time.Now().Add(10 * time.Second)
+
+		urls := []*cli_service.TSparkArrowResultLink{
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 0,
+				RowCount:       3,
+			},
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 3,
+				RowCount:       3,
+			},
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 6,
+				RowCount:       3,
+			},
+		}
+
+		bl, err := NewCloudBatchLoader(
+			context.Background(),
+			urls,
+			0, config.WithDefaults(),
+			&testStatusGetter{},
+			logger.Logger,
+		)
+		assert.Nil(t, err)
+
+		for i := range urls {
+			batch, err := bl.GetBatchFor(urls[i].StartRowOffset + 1)
+			assert.Nil(t, err)
+			assert.NotNil(t, batch)
+			assert.Equal(t, urls[i].RowCount, batch.Count())
+			assert.Equal(t, urls[i].StartRowOffset, batch.Start())
+		}
+
+		assert.Equal(t, len(urls), nLoads)
+	})
+
+	t.Run("test link load failure", func(t *testing.T) {
+		var nLoads int
+		var handler func(w http.ResponseWriter, r *http.Request) = func(w http.ResponseWriter, r *http.Request) {
+			nLoads += 1
+			if nLoads == 3 {
+				w.WriteHeader(http.StatusInternalServerError)
+			} else {
+				w.WriteHeader(http.StatusOK)
+				_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
+				if err != nil {
+					panic(err)
+				}
+			}
+		}
+
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			handler(w, r)
+		}))
+		defer server.Close()
+
+		expiryTime := time.Now().Add(10 * time.Second)
+
+		urls := []*cli_service.TSparkArrowResultLink{
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 0,
+				RowCount:       3,
+			},
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 3,
+				RowCount:       3,
+			},
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 6,
+				RowCount:       3,
+			},
+			{
+				FileLink:       server.URL,
+				ExpiryTime:     expiryTime.Unix(),
+				StartRowOffset: 9,
+				RowCount:       3,
+			},
+		}
+
+		cfg := config.WithDefaults()
+		cfg.MaxDownloadThreads = 1
+
+		bl, err := NewCloudBatchLoader(
+			context.Background(),
+			urls,
+			0,
+			cfg,
+			&testStatusGetter{},
+			logger.Logger,
+		)
+		assert.Nil(t, err)
+
+		batch, err := bl.GetBatchFor(0)
+		assert.Nil(t, err)
+		assert.NotNil(t, batch)
+
+		time.Sleep(1 * time.Second)
+
+		for _, i := range []int{0, 1} {
+			batch, err := bl.GetBatchFor(urls[i].StartRowOffset + 1)
+			assert.Nil(t, err)
+			assert.NotNil(t, batch)
+			assert.Equal(t, urls[i].RowCount, batch.Count())
+			assert.Equal(t, urls[i].StartRowOffset, batch.Start())
+		}
+
+		for _, i := range []int{2, 3} {
+			_, err := bl.GetBatchFor(urls[i].StartRowOffset + 1)
+			assert.NotNil(t, err)
+
+		}
+
+	})
+}
+
+func TestHeartBeat(t *testing.T) {
+
+	t.Run("test heartbeat", func(t *testing.T) {
+		sg := &testStatusGetter{}
+
+		hb := &heartBeat{pinger: sg, interval: 100 * time.Millisecond, logger: logger.Logger}
+		hb.Start()
+		time.Sleep(1 * time.Second)
+		assert.True(t, hb.running)
+		hb.Stop()
+		time.Sleep(1 * time.Second)
+		assert.False(t, hb.running)
+		assert.Nil(t, hb.err)
+		assert.GreaterOrEqual(t, hb.beatCount, 8)
+		assert.LessOrEqual(t, hb.beatCount, 12)
+	})
+
+	t.Run("stop on error", func(t *testing.T) {
+		var beatCount int
+		f := func() error {
+			beatCount += 1
+			if beatCount == 5 {
+				return errors.New("get status error")
+			}
+			return nil
+		}
+		sg := &testStatusGetter{f: f}
+
+		hb := &heartBeat{pinger: sg, interval: 100 * time.Millisecond, logger: logger.Logger}
+		hb.Start()
+		time.Sleep(1 * time.Second)
+
+		assert.False(t, hb.running)
+		assert.EqualError(t, hb.err, "get status error")
+
+	})
+}
+
+type testStatusGetter struct {
+	f func() error
+}
+
+func (sg *testStatusGetter) Ping(ctx context.Context) error {
+	if sg.f != nil {
+		return sg.f()
+	}
+	return nil
+}
+
 func generateArrowRecord() arrow.Record {
 	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
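To make the timing contract concrete: after Start, the heartbeat goroutine pings once per interval (cfg.HeartbeatInterval, 30 seconds by default) until Stop closes stopChan, and it exits on its own and records the error after the first failed ping. A small in-package sketch of that lifecycle under those assumptions (fakePinger and heartBeatDemo are illustrative names; heartBeat, its fields, and logger.Logger come from this change):

    package arrowbased

    import (
    	"context"
    	"errors"
    	"time"

    	"github.com/databricks/databricks-sql-go/logger"
    )

    // fakePinger fails on its third ping, standing in for a dropped connection.
    type fakePinger struct{ calls int }

    func (p *fakePinger) Ping(ctx context.Context) error {
    	p.calls++
    	if p.calls == 3 {
    		return errors.New("connection lost")
    	}
    	return nil
    }

    func heartBeatDemo() {
    	hb := &heartBeat{
    		pinger:   &fakePinger{},
    		interval: 100 * time.Millisecond, // cfg.HeartbeatInterval in real use
    		logger:   logger.Logger,
    	}
    	hb.Start()                         // pings run on a background goroutine
    	time.Sleep(500 * time.Millisecond) // the third ping fails roughly 300ms in
    	// After the failed ping the goroutine has already exited and recorded the
    	// error; on the happy path Stop is what ends the loop.
    	if hb.err != nil {
    		_ = hb.err.Error() // "connection lost"
    	}
    	hb.Stop()
    }
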
logger.Debug().Msgf("databricks: creating Rows with direct results") @@ -128,6 +135,11 @@ func NewRows( if err != nil { return r, err } + + closedOnServer = directResults.CloseOperation != nil + if directResults.ResultSet != nil { + hasMoreRows = directResults.ResultSet.GetHasMoreRows() + } } var d rowscanner.Delimiter @@ -139,12 +151,13 @@ func NewRows( // If the entire query result set fits in direct results the server closes // the operations. - closedOnServer := directResults != nil && directResults.CloseOperation != nil + r.ResultPageIterator = rowscanner.NewResultPageIterator( d, pageSize, opHandle, closedOnServer, + hasMoreRows, client, connId, correlationId, @@ -495,13 +508,12 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql var rs rowscanner.RowScanner var err dbsqlerr.DBError if fetchResults.Results != nil { - if fetchResults.Results.Columns != nil { rs, err = columnbased.NewColumnRowScanner(schema, fetchResults.Results, r.config, r.logger(), r.ctx) } else if fetchResults.Results.ArrowBatches != nil { - rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx) + rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx, r.pinger) } else if fetchResults.Results.ResultLinks != nil { - rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx) + rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx, r.pinger) } else { r.logger().Error().Msg(errRowsUnknowRowType) err = dbsqlerr_int.NewDriverError(r.ctx, errRowsUnknowRowType, nil) @@ -542,5 +554,26 @@ func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterato return r.RowScanner.GetArrowBatches(ctx, *r.config, r.ResultPageIterator) } - return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil + return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config, r.pinger, r.logger()), nil +} + +// statusGetter implements driver.Ping by querying the operation status +type statusGetter struct { + opHandle *cli_service.TOperationHandle + client cli_service.TCLIService + ctx context.Context + logger *dbsqllog.DBSQLLogger +} + +func (sg *statusGetter) Ping(ctx context.Context) error { + _, err := sg.client.GetOperationStatus(sg.ctx, &cli_service.TGetOperationStatusReq{OperationHandle: sg.opHandle}) + if err != nil { + sg.logger.Err(err).Msg("databricks: status getter failed GetOperationStatus") + } + + return err +} + +func newStatusPinger(ctx context.Context, client cli_service.TCLIService, opHandle *cli_service.TOperationHandle, logger *dbsqllog.DBSQLLogger) driver.Pinger { + return &statusGetter{opHandle: opHandle, client: client, ctx: ctx, logger: logger} } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index c83fc41..1f09ad1 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -222,6 +222,7 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { 1000, nil, false, + true, client, "connId", "corrId", @@ -316,6 +317,7 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { 1000, nil, false, + true, client, "connId", "corrId", @@ -465,6 +467,7 @@ func TestNextNoDirectResults(t *testing.T) { 1000, nil, false, + true, client, "connId", "corrId", @@ -760,6 +763,7 @@ func TestFetchResultsWithRetries(t *testing.T) { 1000, nil, false, + true, client, "connId", 
"corrId", diff --git a/internal/rows/rowscanner/resultPageIterator.go b/internal/rows/rowscanner/resultPageIterator.go index fc697f7..fb6ca1c 100644 --- a/internal/rows/rowscanner/resultPageIterator.go +++ b/internal/rows/rowscanner/resultPageIterator.go @@ -49,6 +49,7 @@ func NewResultPageIterator( maxPageSize int64, opHandle *cli_service.TOperationHandle, closedOnServer bool, + hasMoreRows bool, client cli_service.TCLIService, connectionId string, correlationId string, @@ -56,10 +57,12 @@ func NewResultPageIterator( ) ResultPageIterator { // delimiter and hasMoreRows are used to set up the point in the paginated - // result set that this iterator starts from. - return &resultPageIterator{ + // result set that this iterator starts from. It is possible to have a + // case where hasMoreRows is false but the operation is not yet closed on + // the server so we have separate flags. + rpf := &resultPageIterator{ Delimiter: delimiter, - isFinished: closedOnServer, + isFinished: !hasMoreRows, maxPageSize: maxPageSize, opHandle: opHandle, closedOnServer: closedOnServer, @@ -68,6 +71,8 @@ func NewResultPageIterator( correlationId: correlationId, logger: logger, } + + return rpf } type resultPageIterator struct { @@ -168,6 +173,7 @@ func (rpf *resultPageIterator) Close() (err error) { // need to do that now if !rpf.closedOnServer { rpf.closedOnServer = true + rpf.isFinished = true req := cli_service.TCloseOperationReq{ OperationHandle: rpf.opHandle,