diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 8a2e685..7735226 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,4 +1,4 @@ -name: Go +name: CI on: push: @@ -19,14 +19,25 @@ jobs: uses: golangci/golangci-lint-action@v2 with: version: v1.48 + build-and-test: name: Test and Build + permissions: + contents: read + packages: read strategy: matrix: go-version: [1.19.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} - + services: + test-server: + credentials: + username: ${{ github.repository }} + password: ${{ secrets.GITHUB_TOKEN }} + image: ghcr.io/databricks/databricks/databricks-thrift-test-server:0.1.0 + ports: + - 8087:8087 steps: - name: Check out code into the Go module directory uses: actions/checkout@v3 @@ -60,10 +71,15 @@ jobs: fi go get -v -t -d ./... + - name: Build + run: make linux + - name: Test run: make test env: CGO_ENABLED: 0 - - name: Build - run: make linux + - name: Contract Test + run: make test-integration + env: + CGO_ENABLED: 0 \ No newline at end of file diff --git a/Makefile b/Makefile index 9c81d88..c7291e7 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,12 @@ test: bin/gotestsum ## Run the go unit tests. @echo "INFO: Running all go unit tests." CGO_ENABLED=0 ./bin/gotestsum --format pkgname-and-test-fails --junitfile $(TEST_RESULTS_DIR)/unit-tests.xml ./... +.PHONY: test-integration +test-integration: bin/gotestsum ## Run the go unit integration tests. + @echo "INFO: Running all go integration tests." + CGO_ENABLED=0 ./bin/gotestsum --format pkgname-and-test-fails -- -tags=integration ./... + + .PHONY: coverage coverage: bin/gotestsum ## Report the unit test code coverage. @echo "INFO: Generating unit test coverage report." diff --git a/connector.go b/connector.go index 7b6935c..e29c124 100644 --- a/connector.go +++ b/connector.go @@ -21,12 +21,20 @@ type connector struct { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var catalogName *cli_service.TIdentifier var schemaName *cli_service.TIdentifier + var initialNamespace *cli_service.TNamespace + if c.cfg.Catalog != "" { catalogName = cli_service.TIdentifierPtr(cli_service.TIdentifier(c.cfg.Catalog)) } if c.cfg.Schema != "" { schemaName = cli_service.TIdentifierPtr(cli_service.TIdentifier(c.cfg.Schema)) } + if catalogName != nil || schemaName != nil { + initialNamespace = &cli_service.TNamespace{ + CatalogName: catalogName, + SchemaName: schemaName, + } + } tclient, err := client.InitThriftClient(c.cfg) if err != nil { @@ -34,13 +42,11 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // we need to ensure that open session will eventually end + session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ - ClientProtocol: c.cfg.ThriftProtocolVersion, - Configuration: make(map[string]string), - InitialNamespace: &cli_service.TNamespace{ - CatalogName: catalogName, - SchemaName: schemaName, - }, + ClientProtocol: c.cfg.ThriftProtocolVersion, + Configuration: make(map[string]string), + InitialNamespace: initialNamespace, CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, }) diff --git a/integration_tests/integration.go b/integration_tests/integration.go new file mode 100644 index 0000000..6108072 --- /dev/null +++ b/integration_tests/integration.go @@ -0,0 +1 @@ +package integration_tests diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go new file mode 100644 index 0000000..925fa13 --- /dev/null +++ b/integration_tests/integration_test.go @@ -0,0 +1,75 @@ +//go:build integration + +package integration_tests + +import ( + "context" + "database/sql" + "testing" + + dbsql "github.com/databricks/databricks-sql-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIntegrationPing(t *testing.T) { + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname("localhost"), + dbsql.WithPort(8087), + dbsql.WithHTTPPath("session"), //test case name + dbsql.WithMaxRows(100), + ) + require.NoError(t, err) + db := sql.OpenDB(connector) + defer db.Close() + + t.Run("it can query multiple result pages", func(t *testing.T) { + rows, err1 := db.QueryContext(context.Background(), `select * from default.diamonds limit 250`) + require.Nil(t, err1) + require.NotNil(t, rows) + defer rows.Close() + type row struct { + _c0 int + carat float64 + cut string + color string + clarity string + depth sql.NullFloat64 + table sql.NullFloat64 + price int + x float64 + y float64 + z float64 + } + expectedColumnNames := []string{"_c0", "carat", "cut", "color", "clarity", "depth", "table", "price", "x", "y", "z"} + expectedDatabaseType := []string{"INT", "DOUBLE", "STRING", "STRING", "STRING", "DOUBLE", "DOUBLE", "INT", "DOUBLE", "DOUBLE", "DOUBLE"} + expectedNullable := []bool{false, false, false, false, false, false, false, false, false, false, false} + + cols, err := rows.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, cols) + + types, err := rows.ColumnTypes() + require.NoError(t, err) + + for i, v := range types { + assert.Equal(t, expectedColumnNames[i], v.Name()) + assert.Equal(t, expectedDatabaseType[i], v.DatabaseTypeName()) + nullable, ok := v.Nullable() + assert.False(t, ok) + assert.Equal(t, expectedNullable[i], nullable) + } + var allrows []row + for rows.Next() { + // After row 10 this will cause one fetch call, as 10 rows (maxRows config) will come from the first execute statement call. + r := row{} + err := rows.Scan(&r._c0, &r.carat, &r.cut, &r.color, &r.clarity, &r.depth, &r.table, &r.price, &r.x, &r.y, &r.z) + assert.Nil(t, err) + allrows = append(allrows, r) + } + assert.Equal(t, 250, len(allrows)) + assert.Nil(t, rows.Err()) + + }) +}