Skip to content

Commit 3d144c5

Browse files
authored
Feat(dbt): Add support for adapter.expand_target_column_types (#5206)
1 parent 4d8e831 commit 3d144c5

File tree

3 files changed

+134
-13
lines changed

3 files changed

+134
-13
lines changed

docs/integrations/dbt.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -344,26 +344,23 @@ Model documentation is available in the [SQLMesh UI](../quickstart/ui.md#2-open-
344344

345345
SQLMesh supports running dbt projects using the majority of dbt jinja methods, including:
346346

347-
| Method | Method | Method | Method |
348-
| ----------- | -------------- | ------------ | ------- |
349-
| adapter (*) | env_var | project_name | target |
350-
| as_bool | exceptions | ref | this |
351-
| as_native | from_yaml | return | to_yaml |
352-
| as_number | is_incremental | run_query | var |
353-
| as_text | load_result | schema | zip |
354-
| api | log | set | |
355-
| builtins | modules | source | |
356-
| config | print | statement | |
357-
358-
\* `adapter.expand_target_column_types` is not currently supported.
347+
| Method | Method | Method | Method |
348+
| --------- | -------------- | ------------ | ------- |
349+
| adapter | env_var | project_name | target |
350+
| as_bool | exceptions | ref | this |
351+
| as_native | from_yaml | return | to_yaml |
352+
| as_number | is_incremental | run_query | var |
353+
| as_text | load_result | schema | zip |
354+
| api | log | set | |
355+
| builtins | modules | source | |
356+
| config | print | statement | |
359357

360358
## Unsupported dbt jinja methods
361359

362360
The dbt jinja methods that are not currently supported are:
363361

364362
* debug
365363
* selected_sources
366-
* adapter.expand_target_column_types
367364
* graph.nodes.values
368365
* graph.metrics.values
369366

sqlmesh/dbt/adapter.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
1313
from sqlmesh.utils.jinja import JinjaMacroRegistry
1414
from sqlmesh.utils import AttributeDict
15+
from sqlmesh.core.schema_diff import TableAlterOperation
1516

1617
if t.TYPE_CHECKING:
1718
import agate
@@ -85,6 +86,12 @@ def drop_schema(self, relation: BaseRelation) -> None:
8586
def drop_relation(self, relation: BaseRelation) -> None:
8687
"""Drops a relation (table) in the target database."""
8788

89+
@abc.abstractmethod
90+
def expand_target_column_types(
91+
self, from_relation: BaseRelation, to_relation: BaseRelation
92+
) -> None:
93+
"""Expand to_relation's column types to match those of from_relation."""
94+
8895
@abc.abstractmethod
8996
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
9097
"""Renames a relation (table) in the target database."""
@@ -213,6 +220,11 @@ def drop_schema(self, relation: BaseRelation) -> None:
213220
def drop_relation(self, relation: BaseRelation) -> None:
214221
self._raise_parsetime_adapter_call_error("drop relation")
215222

223+
def expand_target_column_types(
224+
self, from_relation: BaseRelation, to_relation: BaseRelation
225+
) -> None:
226+
self._raise_parsetime_adapter_call_error("expand target column types")
227+
216228
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
217229
self._raise_parsetime_adapter_call_error("rename relation")
218230

@@ -355,6 +367,39 @@ def drop_relation(self, relation: BaseRelation) -> None:
355367
if relation.schema is not None and relation.identifier is not None:
356368
self.engine_adapter.drop_table(self._normalize(self._relation_to_table(relation)))
357369

370+
def expand_target_column_types(
371+
self, from_relation: BaseRelation, to_relation: BaseRelation
372+
) -> None:
373+
from_dbt_columns = {c.name: c for c in self.get_columns_in_relation(from_relation)}
374+
to_dbt_columns = {c.name: c for c in self.get_columns_in_relation(to_relation)}
375+
376+
from_table_name = self._normalize(self._relation_to_table(from_relation))
377+
to_table_name = self._normalize(self._relation_to_table(to_relation))
378+
379+
from_columns = self.engine_adapter.columns(from_table_name)
380+
to_columns = self.engine_adapter.columns(to_table_name)
381+
382+
current_columns = {}
383+
new_columns = {}
384+
for column_name, from_column in from_dbt_columns.items():
385+
target_column = to_dbt_columns.get(column_name)
386+
if target_column is not None and target_column.can_expand_to(from_column):
387+
current_columns[column_name] = to_columns[column_name]
388+
new_columns[column_name] = from_columns[column_name]
389+
390+
alter_expressions = t.cast(
391+
t.List[TableAlterOperation],
392+
self.engine_adapter.schema_differ.compare_columns(
393+
to_table_name,
394+
current_columns,
395+
new_columns,
396+
ignore_destructive=True,
397+
),
398+
)
399+
400+
if alter_expressions:
401+
self.engine_adapter.alter_table(alter_expressions)
402+
358403
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
359404
old_table_name = self._normalize(self._relation_to_table(from_relation))
360405
new_table_name = self._normalize(self._relation_to_table(to_relation))

tests/dbt/test_adapter.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
1919
from sqlmesh.utils.errors import ConfigError
2020
from sqlmesh.utils.jinja import JinjaMacroRegistry
21+
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterChangeColumnTypeOperation
2122

2223
pytestmark = pytest.mark.dbt
2324

@@ -349,3 +350,81 @@ def test_adapter_get_relation_normalization(
349350
renderer("{{ adapter.list_relations(database=None, schema='foo') }}")
350351
== '[<SnowflakeRelation "memory"."FOO"."BAR">]'
351352
)
353+
354+
355+
def test_adapter_expand_target_column_types(
356+
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
357+
):
358+
from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType
359+
360+
data_object_from = DataObject(
361+
catalog="test", schema="foo", name="from_table", type=DataObjectType.TABLE
362+
)
363+
data_object_to = DataObject(
364+
catalog="test", schema="foo", name="to_table", type=DataObjectType.TABLE
365+
)
366+
from_columns = {
367+
"int_col": exp.DataType.build("int"),
368+
"same_text_col": exp.DataType.build("varchar(1)"), # varchar(1) -> varchar(1)
369+
"unexpandable_text_col": exp.DataType.build("varchar(2)"), # varchar(4) -> varchar(2)
370+
"expandable_text_col1": exp.DataType.build("varchar(16)"), # varchar(8) -> varchar(16)
371+
"expandable_text_col2": exp.DataType.build("varchar(64)"), # varchar(32) -> varchar(64)
372+
}
373+
to_columns = {
374+
"int_col": exp.DataType.build("int"),
375+
"same_text_col": exp.DataType.build("varchar(1)"),
376+
"unexpandable_text_col": exp.DataType.build("varchar(4)"),
377+
"expandable_text_col1": exp.DataType.build("varchar(8)"),
378+
"expandable_text_col2": exp.DataType.build("varchar(32)"),
379+
}
380+
adapter_mock = mocker.MagicMock()
381+
adapter_mock.default_catalog = "test"
382+
adapter_mock.get_data_object.side_effect = [data_object_from, data_object_to]
383+
# columns() is called 4 times, twice by adapter.get_columns_in_relation() and twice by the engine_adapter
384+
adapter_mock.columns.side_effect = [
385+
from_columns,
386+
to_columns,
387+
from_columns,
388+
to_columns,
389+
]
390+
adapter_mock.schema_differ = SchemaDiffer()
391+
392+
context = sushi_test_project.context
393+
renderer = runtime_renderer(context, engine_adapter=adapter_mock)
394+
395+
renderer("""
396+
{%- set from_relation = adapter.get_relation(
397+
database=None,
398+
schema='foo',
399+
identifier='from_table') -%}
400+
401+
{% set to_relation = adapter.get_relation(
402+
database=None,
403+
schema='foo',
404+
identifier='to_table') -%}
405+
406+
{% do adapter.expand_target_column_types(from_relation, to_relation) %}
407+
""")
408+
adapter_mock.get_data_object.assert_has_calls(
409+
[
410+
call(exp.to_table('"test"."foo"."from_table"')),
411+
call(exp.to_table('"test"."foo"."to_table"')),
412+
]
413+
)
414+
assert len(adapter_mock.alter_table.call_args.args) == 1
415+
alter_expressions = adapter_mock.alter_table.call_args.args[0]
416+
assert len(alter_expressions) == 2
417+
alter_operation1 = alter_expressions[0]
418+
assert isinstance(alter_operation1, TableAlterChangeColumnTypeOperation)
419+
assert alter_operation1.expression == parse_one(
420+
"""ALTER TABLE "test"."foo"."to_table"
421+
ALTER COLUMN expandable_text_col1
422+
SET DATA TYPE VARCHAR(16)"""
423+
)
424+
alter_operation2 = alter_expressions[1]
425+
assert isinstance(alter_operation2, TableAlterChangeColumnTypeOperation)
426+
assert alter_operation2.expression == parse_one(
427+
"""ALTER TABLE "test"."foo"."to_table"
428+
ALTER COLUMN expandable_text_col2
429+
SET DATA TYPE VARCHAR(64)"""
430+
)

0 commit comments

Comments
 (0)