From 50cb8c7830478a4de12d0c809a1f63550fa3cc7b Mon Sep 17 00:00:00 2001
From: eakmanrq <6326532+eakmanrq@users.noreply.github.com>
Date: Fri, 22 Aug 2025 09:08:22 -0700
Subject: [PATCH] feat: widen dbt-core compatibility range

---
 .github/workflows/pr.yaml        | 60 ++++++++++++++++++++++++++++++++
 Makefile                         | 27 +++++++++-----
 examples/sushi_dbt/profiles.yml  |  8 +++++
 pyproject.toml                   |  2 +-
 sqlmesh/dbt/loader.py            |  7 ++--
 sqlmesh/dbt/manifest.py          | 32 +++++++++++++----
 sqlmesh/dbt/relation.py          |  6 ++--
 sqlmesh/dbt/seed.py              | 48 +++++++++++++------------
 sqlmesh/dbt/util.py              |  6 ++--
 tests/dbt/conftest.py            | 14 ++++++++
 tests/dbt/test_adapter.py        |  7 ++++
 tests/dbt/test_config.py         | 35 ++++++++++++-------
 tests/dbt/test_integration.py    | 10 +++++-
 tests/dbt/test_manifest.py       | 21 ++++++++---
 tests/dbt/test_model.py          |  6 +++-
 tests/dbt/test_transformation.py | 35 ++++++++++++++-----
 16 files changed, 249 insertions(+), 75 deletions(-)

diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index b63f6a3ab6..3e715e1318 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -8,6 +8,8 @@ on:
 concurrency:
   group: 'pr-${{ github.event.pull_request.number }}'
   cancel-in-progress: true
+permissions:
+  contents: read
 jobs:
   test-vscode:
     env:
@@ -66,3 +68,61 @@ jobs:
           name: playwright-report
           path: vscode/extension/playwright-report/
           retention-days: 30
+  test-dbt-versions:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        dbt-version:
+          [
+            '1.3.0',
+            '1.4.0',
+            '1.5.0',
+            '1.6.0',
+            '1.7.0',
+            '1.8.0',
+            '1.9.0',
+            '1.10.0',
+          ]
+    steps:
+      - uses: actions/checkout@v5
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+      - name: Install SQLMesh dev dependencies
+        run: |
+          uv venv .venv
+          source .venv/bin/activate
+          sed -i 's/"pydantic>=2.0.0"/"pydantic"/g' pyproject.toml
+          if [[ "${{ matrix.dbt-version }}" == "1.10.0" ]]; then
+            # For 1.10.0: only add version to dbt-core, remove versions from all adapter packages
+            sed -i -E 's/"(dbt-core)[^"]*"/"\1~=${{ matrix.dbt-version }}"/g' pyproject.toml
+            # Remove version constraints from all dbt adapter packages
+            sed -i -E 's/"(dbt-(bigquery|duckdb|snowflake|athena-community|clickhouse|databricks|redshift|trino))[^"]*"/"\1"/g' pyproject.toml
+          else
+            # For other versions: apply version to all dbt packages
+            sed -i -E 's/"(dbt-[^">=<~!]+)[^"]*"/"\1~=${{ matrix.dbt-version }}"/g' pyproject.toml
+          fi
+          UV=1 make install-dev
+          uv pip install "pydantic>=2.0.0" --reinstall
+      - name: Run dbt tests
+        # We can't run the slow tests against every dbt version: they require DuckDB,
+        # and older dbt versions pin a DuckDB version we no longer support.
+        run: |
+          source .venv/bin/activate
+          make dbt-fast-test
+      - name: Test SQLMesh info in sushi_dbt
+        working-directory: ./examples/sushi_dbt
+        run: |
+          source ../../.venv/bin/activate
+          sed -i 's/target: in_memory/target: postgres/g' profiles.yml
+          if [[ $(echo -e "${{ matrix.dbt-version }}\n1.5.0" | sort -V | head -n1) == "${{ matrix.dbt-version }}" ]] && [[ "${{ matrix.dbt-version }}" != "1.5.0" ]]; then
+            echo "DBT version is ${{ matrix.dbt-version }} (< 1.5.0), removing version parameters..."
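+            # ref(..., version=N) / ref(..., v=N) arguments only exist on dbt >= 1.5,
+            # so strip them from the model before running SQLMesh on older versions.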
+ sed -i -e 's/, version=1) }}/) }}/g' -e 's/, v=1) }}/) }}/g' models/top_waiters.sql + else + echo "DBT version is ${{ matrix.dbt-version }} (>= 1.5.0), keeping version parameters" + fi + sqlmesh info --skip-connection diff --git a/Makefile b/Makefile index bad2cf2907..04306946cd 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,16 @@ .PHONY: docs +ifdef UV + PIP := uv pip +else + PIP := pip3 +endif + install-dev: - pip3 install -e ".[dev,web,slack,dlt,lsp]" ./examples/custom_materializations + $(PIP) install -e ".[dev,web,slack,dlt,lsp]" ./examples/custom_materializations install-doc: - pip3 install -r ./docs/requirements.txt + $(PIP) install -r ./docs/requirements.txt install-pre-commit: pre-commit install @@ -22,16 +28,16 @@ doc-test: python -m pytest --doctest-modules sqlmesh/core sqlmesh/utils package: - pip3 install build && python3 -m build + $(PIP) install build && python3 -m build publish: package - pip3 install twine && python3 -m twine upload dist/* + $(PIP) install twine && python3 -m twine upload dist/* package-tests: - pip3 install build && cp pyproject.toml tests/sqlmesh_pyproject.toml && python3 -m build tests/ + $(PIP) install build && cp pyproject.toml tests/sqlmesh_pyproject.toml && python3 -m build tests/ publish-tests: package-tests - pip3 install twine && python3 -m twine upload -r tobiko-private tests/dist/* + $(PIP) install twine && python3 -m twine upload -r tobiko-private tests/dist/* docs-serve: mkdocs serve @@ -93,6 +99,9 @@ engine-test: dbt-test: pytest -n auto -m "dbt and not cicdonly" +dbt-fast-test: + pytest -n auto -m "dbt and fast" --retries 3 + github-test: pytest -n auto -m "github" @@ -109,7 +118,7 @@ guard-%: fi engine-%-install: - pip3 install -e ".[dev,web,slack,lsp,${*}]" ./examples/custom_materializations + $(PIP) install -e ".[dev,web,slack,lsp,${*}]" ./examples/custom_materializations engine-docker-%-up: docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d @@ -157,11 +166,11 @@ snowflake-test: guard-SNOWFLAKE_ACCOUNT guard-SNOWFLAKE_WAREHOUSE guard-SNOWFLAK pytest -n auto -m "snowflake" --retries 3 --junitxml=test-results/junit-snowflake.xml bigquery-test: guard-BIGQUERY_KEYFILE engine-bigquery-install - pip install -e ".[bigframes]" + $(PIP) install -e ".[bigframes]" pytest -n auto -m "bigquery" --retries 3 --junitxml=test-results/junit-bigquery.xml databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard-DATABRICKS_HTTP_PATH guard-DATABRICKS_ACCESS_TOKEN guard-DATABRICKS_CONNECT_VERSION engine-databricks-install - pip install 'databricks-connect==${DATABRICKS_CONNECT_VERSION}' + $(PIP) install 'databricks-connect==${DATABRICKS_CONNECT_VERSION}' pytest -n auto -m "databricks" --retries 3 --junitxml=test-results/junit-databricks.xml redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD guard-REDSHIFT_DATABASE engine-redshift-install diff --git a/examples/sushi_dbt/profiles.yml b/examples/sushi_dbt/profiles.yml index 74de4e472c..794b083793 100644 --- a/examples/sushi_dbt/profiles.yml +++ b/examples/sushi_dbt/profiles.yml @@ -3,6 +3,14 @@ sushi: in_memory: type: duckdb schema: sushi + postgres: + type: postgres + host: "host" + user: "user" + password: "password" + dbname: "dbname" + port: 5432 + schema: sushi duckdb: type: duckdb path: 'local.duckdb' diff --git a/pyproject.toml b/pyproject.toml index e125bfb281..7e532f75f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ bigframes = ["bigframes>=1.32.0"] clickhouse = ["clickhouse-connect"] 
databricks = ["databricks-sql-connector[pyarrow]"] dev = [ - "agate==1.7.1", + "agate", "beautifulsoup4", "clickhouse-connect", "cryptography", diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index d321246896..594c5a8807 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -188,8 +188,11 @@ def _load_projects(self) -> t.List[Project]: self._projects.append(project) - if project.context.target.database != (self.context.default_catalog or ""): - raise ConfigError("Project default catalog does not match context default catalog") + context_default_catalog = self.context.default_catalog or "" + if project.context.target.database != context_default_catalog: + raise ConfigError( + f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')." + ) for path in project.project_files: self._track_file(path) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 85a8f7205e..690bca4a3a 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -13,10 +13,17 @@ from dbt import constants as dbt_constants, flags +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils.conversions import make_serializable # Override the file name to prevent dbt commands from invalidating the cache. -dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" + +if DBT_VERSION >= (1, 6, 0): + dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" # type: ignore +else: + from dbt.parser import manifest as dbt_manifest # type: ignore + + dbt_manifest.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" # type: ignore import jinja2 from dbt.adapters.factory import register_adapter, reset_adapters @@ -379,11 +386,17 @@ def _load_on_run_start_end(self) -> None: if "on-run-start" in node.tags: self._on_run_start_per_package[node.package_name][node_name] = HookConfig( - sql=sql, index=node.index or 0, path=node_path, dependencies=dependencies + sql=sql, + index=getattr(node, "index", None) or 0, + path=node_path, + dependencies=dependencies, ) else: self._on_run_end_per_package[node.package_name][node_name] = HookConfig( - sql=sql, index=node.index or 0, path=node_path, dependencies=dependencies + sql=sql, + index=getattr(node, "index", None) or 0, + path=node_path, + dependencies=dependencies, ) @property @@ -599,6 +612,9 @@ def _macro_references( manifest: Manifest, node: t.Union[ManifestNode, Macro] ) -> t.Set[MacroReference]: result: t.Set[MacroReference] = set() + if not hasattr(node, "depends_on"): + return result + for macro_node_id in node.depends_on.macros: if not macro_node_id: continue @@ -614,18 +630,20 @@ def _macro_references( def _refs(node: ManifestNode) -> t.Set[str]: if DBT_VERSION >= (1, 5, 0): - result = set() + result: t.Set[str] = set() + if not hasattr(node, "refs"): + return result for r in node.refs: - ref_name = f"{r.package}.{r.name}" if r.package else r.name + ref_name = f"{r.package}.{r.name}" if r.package else r.name # type: ignore if getattr(r, "version", None): - ref_name = f"{ref_name}_v{r.version}" + ref_name = f"{ref_name}_v{r.version}" # type: ignore result.add(ref_name) return result return {".".join(r) for r in node.refs} # type: ignore def _sources(node: ManifestNode) -> t.Set[str]: - return {".".join(s) for s in node.sources} + return {".".join(s) for s in getattr(node, "sources", [])} def _model_node_id(model_name: str, package: str) -> str: diff --git a/sqlmesh/dbt/relation.py b/sqlmesh/dbt/relation.py index f68a9ff6de..fff9f75593 
100644 --- a/sqlmesh/dbt/relation.py +++ b/sqlmesh/dbt/relation.py @@ -1,7 +1,7 @@ from sqlmesh.dbt.util import DBT_VERSION -if DBT_VERSION < (1, 8, 0): - from dbt.contracts.relation import * # type: ignore # noqa: F403 -else: +if DBT_VERSION >= (1, 8, 0): from dbt.adapters.contracts.relation import * # type: ignore # noqa: F403 +else: + from dbt.contracts.relation import * # type: ignore # noqa: F403 diff --git a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py index a84e39e653..cf22d961cf 100644 --- a/sqlmesh/dbt/seed.py +++ b/sqlmesh/dbt/seed.py @@ -5,11 +5,13 @@ import agate -try: +from sqlmesh.dbt.util import DBT_VERSION + +if DBT_VERSION >= (1, 8, 0): from dbt_common.clients import agate_helper # type: ignore SUPPORTS_DELIMITER = True -except ImportError: +else: from dbt.clients import agate_helper # type: ignore SUPPORTS_DELIMITER = False @@ -95,27 +97,7 @@ def to_sqlmesh( ) -class Integer(agate_helper.Integer): - def cast(self, d: t.Any) -> t.Optional[int]: - if isinstance(d, str): - # The dbt's implementation doesn't support coercion of strings to integers. - if d.strip().lower() in self.null_values: - return None - try: - return int(d) - except ValueError: - raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d) - return super().cast(d) - - def jsonify(self, d: t.Any) -> str: - return d - - -agate_helper.Integer = Integer # type: ignore - - AGATE_TYPE_MAPPING = { - agate_helper.Integer: exp.DataType.build("int"), agate_helper.Number: exp.DataType.build("double"), agate_helper.ISODateTime: exp.DataType.build("datetime"), agate.Date: exp.DataType.build("date"), @@ -123,3 +105,25 @@ def jsonify(self, d: t.Any) -> str: agate.Boolean: exp.DataType.build("boolean"), agate.Text: exp.DataType.build("text"), } + + +if DBT_VERSION >= (1, 7, 0): + + class Integer(agate_helper.Integer): + def cast(self, d: t.Any) -> t.Optional[int]: + if isinstance(d, str): + # The dbt's implementation doesn't support coercion of strings to integers. + if d.strip().lower() in self.null_values: + return None + try: + return int(d) + except ValueError: + raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' 
% d) + return super().cast(d) + + def jsonify(self, d: t.Any) -> str: + return d + + agate_helper.Integer = Integer # type: ignore + + AGATE_TYPE_MAPPING[agate_helper.Integer] = exp.DataType.build("int") diff --git a/sqlmesh/dbt/util.py b/sqlmesh/dbt/util.py index 9ffca39167..0de16e3b3e 100644 --- a/sqlmesh/dbt/util.py +++ b/sqlmesh/dbt/util.py @@ -20,10 +20,10 @@ def _get_dbt_version() -> t.Tuple[int, int, int]: DBT_VERSION = _get_dbt_version() -if DBT_VERSION < (1, 8, 0): - from dbt.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 -else: +if DBT_VERSION >= (1, 8, 0): from dbt_common.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 +else: + from dbt.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 def pandas_to_agate(df: pd.DataFrame) -> agate.Table: diff --git a/tests/dbt/conftest.py b/tests/dbt/conftest.py index 1852a873c8..5875d9f575 100644 --- a/tests/dbt/conftest.py +++ b/tests/dbt/conftest.py @@ -7,6 +7,7 @@ from sqlmesh.core.context import Context from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.project import Project +from sqlmesh.dbt.target import PostgresConfig @pytest.fixture() @@ -25,3 +26,16 @@ def render(value: str) -> str: return render return create_renderer + + +@pytest.fixture() +def dbt_dummy_postgres_config() -> PostgresConfig: + return PostgresConfig( # type: ignore + name="postgres", + host="host", + user="user", + password="password", + dbname="dbname", + port=5432, + schema="schema", + ) diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 445e5f29c0..85dfa29559 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -23,6 +23,7 @@ pytestmark = pytest.mark.dbt +@pytest.mark.slow def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Callable): context = sushi_test_project.context assert context.target @@ -96,6 +97,7 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla assert engine_adapter.table_exists("foo.bar__backup") +@pytest.mark.slow def test_bigquery_get_columns_in_relation( sushi_test_project: Project, runtime_renderer: t.Callable, @@ -135,6 +137,7 @@ def test_bigquery_get_columns_in_relation( @pytest.mark.cicdonly +@pytest.mark.slow def test_normalization( sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture ): @@ -232,6 +235,7 @@ def test_normalization( adapter_mock.drop_table.assert_has_calls([call(relation_bla_bob)]) +@pytest.mark.slow def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Callable): context = sushi_test_project.context renderer = runtime_renderer(context) @@ -244,6 +248,7 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla @pytest.mark.parametrize("project_dialect", ["duckdb", "bigquery"]) +@pytest.mark.slow def test_adapter_map_snapshot_tables( sushi_test_project: Project, runtime_renderer: t.Callable, @@ -320,6 +325,7 @@ def test_quote_as_configured(): adapter.quote_as_configured("foo", "database") == "foo" +@pytest.mark.slow def test_adapter_get_relation_normalization( sushi_test_project: Project, runtime_renderer: t.Callable ): @@ -352,6 +358,7 @@ def test_adapter_get_relation_normalization( ) +@pytest.mark.slow def test_adapter_expand_target_column_types( sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture ): diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py 
index 695c745c1d..1483225987 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -6,6 +6,8 @@ import pytest from dbt.adapters.base import BaseRelation, Column from pytest_mock import MockerFixture + +from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.config import Config, ModelDefaultsConfig from sqlmesh.core.dialect import jinja_query from sqlmesh.core.model import SqlModel @@ -82,7 +84,7 @@ def test_update(current: t.Dict[str, t.Any], new: t.Dict[str, t.Any], expected: assert {k: v for k, v in config.dict().items() if k in expected} == expected -def test_model_to_sqlmesh_fields(): +def test_model_to_sqlmesh_fields(dbt_dummy_postgres_config: PostgresConfig): model_config = ModelConfig( name="name", package_name="package", @@ -111,7 +113,7 @@ def test_model_to_sqlmesh_fields(): ) context = DbtContext() context.project_name = "Foo" - context.target = DuckDbConfig(name="target", schema="foo") + context.target = dbt_dummy_postgres_config model = model_config.to_sqlmesh(context) assert isinstance(model, SqlModel) @@ -119,7 +121,7 @@ def test_model_to_sqlmesh_fields(): assert model.description == "test model" assert ( model.render_query_or_raise().sql() - == 'SELECT 1 AS "a" FROM "memory"."foo"."table" AS "table"' + == 'SELECT 1 AS "a" FROM "dbname"."foo"."table" AS "table"' ) assert model.start == "Jan 1 2023" assert [col.sql() for col in model.partitioned_by] == ['"a"'] @@ -127,7 +129,7 @@ def test_model_to_sqlmesh_fields(): assert model.cron == "@hourly" assert model.interval_unit.value == "five_minute" assert model.stamp == "bar" - assert model.dialect == "duckdb" + assert model.dialect == "postgres" assert model.owner == "Sally" assert model.tags == ["test", "incremental"] kind = t.cast(IncrementalByUniqueKeyKind, model.kind) @@ -136,8 +138,8 @@ def test_model_to_sqlmesh_fields(): assert kind.on_destructive_change == OnDestructiveChange.ALLOW assert kind.on_additive_change == OnAdditiveChange.ALLOW assert ( - kind.merge_filter.sql(dialect=model.dialect) - == """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7' DAY""" + kind.merge_filter.sql(dialect=model.dialect) # type: ignore + == """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7'""" ) model = model_config.update_with({"dialect": "snowflake"}).to_sqlmesh(context) @@ -147,7 +149,7 @@ def test_model_to_sqlmesh_fields(): sqlmesh_config=Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) ) bq_default_context.project_name = "Foo" - bq_default_context.target = DuckDbConfig(name="target", schema="foo") + bq_default_context.target = dbt_dummy_postgres_config model_config.cluster_by = ["a", "`b`"] model = model_config.to_sqlmesh(bq_default_context) assert model.dialect == "bigquery" @@ -229,7 +231,7 @@ def test_test_to_sqlmesh_fields(): assert audit.dialect == "bigquery" -def test_singular_test_to_standalone_audit(): +def test_singular_test_to_standalone_audit(dbt_dummy_postgres_config: PostgresConfig): sql = "SELECT * FROM FOO.BAR WHERE cost > 100" test_config = TestConfig( name="bar_test", @@ -251,8 +253,8 @@ def test_singular_test_to_standalone_audit(): context = DbtContext() context.add_models({model.name: model}) context._project_name = "Foo" - context.target = DuckDbConfig(name="target", schema="foo") - standalone_audit = test_config.to_sqlmesh(context) + context.target = dbt_dummy_postgres_config + standalone_audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert standalone_audit.name 
== "bar_test" assert standalone_audit.description == "test description" @@ -260,12 +262,12 @@ def test_singular_test_to_standalone_audit(): assert standalone_audit.stamp == "bump" assert standalone_audit.cron == "@monthly" assert standalone_audit.interval_unit.value == "day" - assert standalone_audit.dialect == "duckdb" + assert standalone_audit.dialect == "postgres" assert standalone_audit.query == jinja_query(sql) - assert standalone_audit.depends_on == {'"memory"."foo"."bar"'} + assert standalone_audit.depends_on == {'"dbname"."foo"."bar"'} test_config.dialect_ = "bigquery" - standalone_audit = test_config.to_sqlmesh(context) + standalone_audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert standalone_audit.dialect == "bigquery" @@ -305,6 +307,7 @@ def test_model_config_sql_no_config(): ) +@pytest.mark.slow def test_variables(assert_exp_eq, sushi_test_project): # Case 1: using an undefined variable without a default value defined_variables = {} @@ -384,6 +387,7 @@ def test_variables(assert_exp_eq, sushi_test_project): assert sushi_test_project.packages["customers"].variables == expected_customer_variables +@pytest.mark.slow def test_nested_variables(sushi_test_project): model_config = ModelConfig( alias="sushi.test_nested", @@ -396,6 +400,7 @@ def test_nested_variables(sushi_test_project): assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2} +@pytest.mark.slow def test_source_config(sushi_test_project: Project): source_configs = sushi_test_project.packages["sushi"].sources assert set(source_configs) == { @@ -426,6 +431,7 @@ def test_source_config(sushi_test_project: Project): ) +@pytest.mark.slow def test_seed_config(sushi_test_project: Project, mocker: MockerFixture): seed_configs = sushi_test_project.packages["sushi"].seeds assert set(seed_configs) == {"waiter_names"} @@ -955,6 +961,7 @@ class CustomDbtLoader(DbtLoader): @pytest.mark.cicdonly +@pytest.mark.slow def test_db_type_to_relation_class(): from dbt.adapters.bigquery.relation import BigQueryRelation from dbt.adapters.databricks.relation import DatabricksRelation @@ -978,6 +985,7 @@ def test_db_type_to_relation_class(): @pytest.mark.cicdonly +@pytest.mark.slow def test_db_type_to_column_class(): from dbt.adapters.bigquery import BigQueryColumn from dbt.adapters.databricks.column import DatabricksColumn @@ -1013,6 +1021,7 @@ def test_variable_override(): assert project.packages["sushi"].variables["yet_another_var"] == 2 +@pytest.mark.slow def test_depends_on(assert_exp_eq, sushi_test_project): # Case 1: using an undefined variable without a default value context = sushi_test_project.context diff --git a/tests/dbt/test_integration.py b/tests/dbt/test_integration.py index 45c1422395..ee8c486ab2 100644 --- a/tests/dbt/test_integration.py +++ b/tests/dbt/test_integration.py @@ -7,7 +7,12 @@ import pandas as pd # noqa: TID253 import pytest -from dbt.cli.main import dbtRunner + +from sqlmesh.dbt.util import DBT_VERSION + +if DBT_VERSION >= (1, 5, 0): + from dbt.cli.main import dbtRunner # type: ignore + import time_machine from sqlmesh import Context @@ -303,6 +308,9 @@ def test_scd_type_2_by_time( test_type: TestType, invalidate_hard_deletes: bool, ): + if test_type.is_dbt_runtime and DBT_VERSION < (1, 5, 0): + pytest.skip("The dbt version being tested doesn't support the dbtRunner so skipping.") + run, adapter, context = self._init_test( create_scd_type_2_dbt_project, create_scd_type_2_sqlmesh_project, diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py 
index efbd2687fd..7ad67c3585 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -63,9 +63,12 @@ def test_manifest_helper(caplog): assert models["items_no_hard_delete_snapshot"].invalidate_hard_deletes is False # Test versioned models - assert models["waiter_revenue_by_day_v1"].version == 1 - assert models["waiter_revenue_by_day_v2"].version == 2 - assert "waiter_revenue_by_day" not in models + if DBT_VERSION >= (1, 5, 0): + assert models["waiter_revenue_by_day_v1"].version == 1 + assert models["waiter_revenue_by_day_v2"].version == 2 + assert "waiter_revenue_by_day" not in models + else: + assert "waiter_revenue_by_day" in models waiter_as_customer_by_day_config = models["waiter_as_customer_by_day"] assert waiter_as_customer_by_day_config.dependencies == Dependencies( @@ -77,7 +80,10 @@ def test_manifest_helper(caplog): assert waiter_as_customer_by_day_config.cluster_by == ["ds"] assert waiter_as_customer_by_day_config.time_column == "ds" - waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] + if DBT_VERSION >= (1, 5, 0): + waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] + else: + waiter_revenue_by_day_config = models["waiter_revenue_by_day"] assert waiter_revenue_by_day_config.dependencies == Dependencies( macros={ MacroReference(name="dynamic_var_name_dependency"), @@ -218,7 +224,12 @@ def test_source_meta_external_location(): sources["parquet_file.items"].relation_info, api.Relation, api.quote_policy ) assert relation.identifier == "items" - assert relation.render() == "read_parquet('path/to/external/items.parquet')" + expected = ( + "read_parquet('path/to/external/items.parquet')" + if DBT_VERSION >= (1, 4, 0) + else '"main"."parquet_file".items' + ) + assert relation.render() == expected @pytest.mark.xdist_group("dbt_manifest") diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index 030f2ec723..df9f229900 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -6,6 +6,7 @@ from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig +from sqlmesh.dbt.target import PostgresConfig from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.yaml import YAML @@ -50,7 +51,10 @@ def test_model_test_circular_references() -> None: downstream_model.check_for_circular_test_refs(context) -def test_load_invalid_ref_audit_constraints(tmp_path: Path, caplog) -> None: +@pytest.mark.slow +def test_load_invalid_ref_audit_constraints( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig +) -> None: yaml = YAML() dbt_project_dir = tmp_path / "dbt" dbt_project_dir.mkdir() diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 33c7132551..320b036e6d 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -6,9 +6,15 @@ from pathlib import Path from unittest.mock import patch +from sqlmesh.dbt.util import DBT_VERSION + import pytest from dbt.adapters.base import BaseRelation -from dbt.exceptions import CompilationError + +if DBT_VERSION >= (1, 4, 0): + from dbt.exceptions import CompilationError +else: + from dbt.exceptions import CompilationException as CompilationError # type: ignore import time_machine from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one @@ -47,8 +53,14 @@ from sqlmesh.dbt.model import Materialization, ModelConfig from sqlmesh.dbt.project import Project from sqlmesh.dbt.relation import 
Policy -from sqlmesh.dbt.seed import SeedConfig, Integer -from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig, ClickhouseConfig +from sqlmesh.dbt.seed import SeedConfig +from sqlmesh.dbt.target import ( + BigQueryConfig, + DuckDbConfig, + SnowflakeConfig, + ClickhouseConfig, + PostgresConfig, +) from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError from sqlmesh.utils.jinja import MacroReference @@ -56,9 +68,9 @@ pytestmark = [pytest.mark.dbt, pytest.mark.slow] -def test_model_name(): +def test_model_name(dbt_dummy_postgres_config: PostgresConfig): context = DbtContext() - context._target = DuckDbConfig(name="duckdb", schema="foo") + context._target = dbt_dummy_postgres_config assert ModelConfig(schema="foo", path="models/bar.sql").canonical_name(context) == "foo.bar" assert ( ModelConfig(schema="foo", path="models/bar.sql", alias="baz").canonical_name(context) @@ -66,10 +78,9 @@ def test_model_name(): ) assert ( ModelConfig( - database="memory", schema="foo", path="models/bar.sql", alias="baz" + database="dbname", schema="foo", path="models/bar.sql", alias="baz" ).canonical_name(context) == "foo.baz" - == "foo.baz" ) assert ( ModelConfig( @@ -680,7 +691,9 @@ def test_seed_column_inference(tmp_path): context.target = DuckDbConfig(name="target", schema="test") sqlmesh_seed = seed.to_sqlmesh(context) assert sqlmesh_seed.columns_to_types == { - "int_col": exp.DataType.build("int"), + "int_col": exp.DataType.build("int") + if DBT_VERSION >= (1, 8, 0) + else exp.DataType.build("double"), "double_col": exp.DataType.build("double"), "datetime_col": exp.DataType.build("datetime"), "date_col": exp.DataType.build("date"), @@ -793,6 +806,12 @@ def test_seed_column_order(tmp_path): def test_agate_integer_cast(): + # Not all dbt versions have agate.Integer + if DBT_VERSION < (1, 7, 0): + pytest.skip("agate.Integer not available") + + from sqlmesh.dbt.seed import Integer + agate_integer = Integer(null_values=("null", "")) assert agate_integer.cast("1") == 1 assert agate_integer.cast(1) == 1
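
A note on the compatibility pattern, since it repeats across the diff: every
branch keys off the DBT_VERSION tuple from sqlmesh/dbt/util.py. Only the
helper's signature appears in the hunks above, so the sketch below is
illustrative rather than the shipped implementation; it assumes dbt-core
exposes dbt.version.__version__ (true on current 1.x releases), while the
exact parsing in util.py may differ.

    # Version-gated imports, as used throughout this patch (Python sketch).
    import re
    import typing as t

    def _get_dbt_version() -> t.Tuple[int, int, int]:
        from dbt.version import __version__

        # Keep the numeric components; drop rc/beta suffixes like "1.8.0b2".
        nums = re.findall(r"\d+", __version__)
        return int(nums[0]), int(nums[1]), int(nums[2]) if len(nums) > 2 else 0

    DBT_VERSION = _get_dbt_version()

    # dbt 1.8 moved shared helpers from dbt-core into the dbt-common package.
    # An explicit tuple comparison keeps the cutoff greppable, whereas a bare
    # try/except ImportError would hide which release introduced the split.
    if DBT_VERSION >= (1, 8, 0):
        from dbt_common.clients.agate_helper import empty_table  # type: ignore
    else:
        from dbt.clients.agate_helper import empty_table  # type: ignore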