Skip to content

Commit 5c252f3

Browse files
authored
feat: mimic dbt nuanced on_schema_change behavior (#5203)
1 parent 672806a commit 5c252f3

24 files changed, with 329 additions and 148 deletions.

sqlmesh/core/config/connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
100100
register_comments: bool
101101
pre_ping: bool
102102
pretty_sql: bool = False
103+
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
103104

104105
# Whether to share a single connection across threads or create a new connection per thread.
105106
shared_connection: t.ClassVar[bool] = False
@@ -174,6 +175,7 @@ def create_engine_adapter(
174175
pre_ping=self.pre_ping,
175176
pretty_sql=self.pretty_sql,
176177
shared_connection=self.shared_connection,
178+
schema_differ_overrides=self.schema_differ_overrides,
177179
**self._extra_engine_config,
178180
)
179181

sqlmesh/core/engine_adapter/athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin):
3939
# CTAS, Views: No comment support at all
4040
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
4141
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
42-
SCHEMA_DIFFER = TrinoEngineAdapter.SCHEMA_DIFFER
42+
SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS
4343
MAX_TIMESTAMP_PRECISION = 3 # copied from Trino
4444
# Athena does not deal with comments well, e.g:
4545
# >>> self._execute('/* test */ DESCRIBE foo')

sqlmesh/core/engine_adapter/base.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
import sys
1616
import typing as t
17-
from functools import partial
17+
from functools import cached_property, partial
1818

1919
from sqlglot import Dialect, exp
2020
from sqlglot.errors import ErrorLevel
@@ -109,7 +109,7 @@ class EngineAdapter:
109109
SUPPORTS_MANAGED_MODELS = False
110110
SUPPORTS_CREATE_DROP_CATALOG = False
111111
SUPPORTED_DROP_CASCADE_OBJECT_KINDS: t.List[str] = []
112-
SCHEMA_DIFFER = SchemaDiffer()
112+
SCHEMA_DIFFER_KWARGS: t.Dict[str, t.Any] = {}
113113
SUPPORTS_TUPLE_IN = True
114114
HAS_VIEW_BINDING = False
115115
SUPPORTS_REPLACE_TABLE = True
@@ -132,6 +132,7 @@ def __init__(
132132
pretty_sql: bool = False,
133133
shared_connection: bool = False,
134134
correlation_id: t.Optional[CorrelationId] = None,
135+
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None,
135136
**kwargs: t.Any,
136137
):
137138
self.dialect = dialect.lower() or self.DIALECT
@@ -154,6 +155,7 @@ def __init__(
154155
self._pretty_sql = pretty_sql
155156
self._multithreaded = multithreaded
156157
self.correlation_id = correlation_id
158+
self._schema_differ_overrides = schema_differ_overrides
157159

158160
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
159161
extra_kwargs = {
@@ -204,6 +206,15 @@ def comments_enabled(self) -> bool:
204206
def catalog_support(self) -> CatalogSupport:
205207
return CatalogSupport.UNSUPPORTED
206208

209+
@cached_property
210+
def schema_differ(self) -> SchemaDiffer:
211+
return SchemaDiffer(
212+
**{
213+
**self.SCHEMA_DIFFER_KWARGS,
214+
**(self._schema_differ_overrides or {}),
215+
}
216+
)
217+
207218
@classmethod
208219
def _casted_columns(
209220
cls,
@@ -1101,7 +1112,7 @@ def get_alter_operations(
11011112
"""
11021113
return t.cast(
11031114
t.List[TableAlterOperation],
1104-
self.SCHEMA_DIFFER.compare_columns(
1115+
self.schema_differ.compare_columns(
11051116
current_table_name,
11061117
self.columns(current_table_name),
11071118
self.columns(target_table_name),

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
set_catalog,
2323
)
2424
from sqlmesh.core.node import IntervalUnit
25-
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
25+
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
2626
from sqlmesh.utils import optional_import, get_source_columns_to_types
2727
from sqlmesh.utils.date import to_datetime
2828
from sqlmesh.utils.errors import SQLMeshError
@@ -68,8 +68,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
6868
MAX_COLUMN_COMMENT_LENGTH = 1024
6969
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
7070

71-
SCHEMA_DIFFER = SchemaDiffer(
72-
compatible_types={
71+
SCHEMA_DIFFER_KWARGS = {
72+
"compatible_types": {
7373
exp.DataType.build("INT64", dialect=DIALECT): {
7474
exp.DataType.build("NUMERIC", dialect=DIALECT),
7575
exp.DataType.build("FLOAT64", dialect=DIALECT),
@@ -83,17 +83,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
8383
exp.DataType.build("DATETIME", dialect=DIALECT),
8484
},
8585
},
86-
coerceable_types={
86+
"coerceable_types": {
8787
exp.DataType.build("FLOAT64", dialect=DIALECT): {
8888
exp.DataType.build("BIGNUMERIC", dialect=DIALECT),
8989
},
9090
},
91-
support_coercing_compatible_types=True,
92-
parameterized_type_defaults={
91+
"support_coercing_compatible_types": True,
92+
"parameterized_type_defaults": {
9393
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)],
9494
exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)],
9595
},
96-
types_with_unlimited_length={
96+
"types_with_unlimited_length": {
9797
# parameterized `STRING(n)` can ALTER to unparameterized `STRING`
9898
exp.DataType.build("STRING", dialect=DIALECT).this: {
9999
exp.DataType.build("STRING", dialect=DIALECT).this,
@@ -103,9 +103,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
103103
exp.DataType.build("BYTES", dialect=DIALECT).this,
104104
},
105105
},
106-
support_nested_operations=True,
107-
support_nested_drop=False,
108-
)
106+
"nested_support": NestedSupport.ALL_BUT_DROP,
107+
}
109108

110109
@property
111110
def client(self) -> BigQueryClient:

sqlmesh/core/engine_adapter/clickhouse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
CommentCreationView,
1616
InsertOverwriteStrategy,
1717
)
18-
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
18+
from sqlmesh.core.schema_diff import TableAlterOperation
1919
from sqlmesh.utils import get_source_columns_to_types
2020

2121
if t.TYPE_CHECKING:
@@ -37,7 +37,7 @@ class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin):
3737
SUPPORTS_REPLACE_TABLE = False
3838
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
3939

40-
SCHEMA_DIFFER = SchemaDiffer()
40+
SCHEMA_DIFFER_KWARGS = {}
4141

4242
DEFAULT_TABLE_ENGINE = "MergeTree"
4343
ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$"

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
1717
from sqlmesh.core.node import IntervalUnit
18-
from sqlmesh.core.schema_diff import SchemaDiffer
18+
from sqlmesh.core.schema_diff import NestedSupport
1919
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
2020
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2121

@@ -34,15 +34,14 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
3434
SUPPORTS_CLONING = True
3535
SUPPORTS_MATERIALIZED_VIEWS = True
3636
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
37-
SCHEMA_DIFFER = SchemaDiffer(
38-
support_positional_add=True,
39-
support_nested_operations=True,
40-
support_nested_drop=True,
41-
array_element_selector="element",
42-
parameterized_type_defaults={
37+
SCHEMA_DIFFER_KWARGS = {
38+
"support_positional_add": True,
39+
"nested_support": NestedSupport.ALL,
40+
"array_element_selector": "element",
41+
"parameterized_type_defaults": {
4342
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
4443
},
45-
)
44+
}
4645

4746
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
4847
super().__init__(*args, **kwargs)

sqlmesh/core/engine_adapter/duckdb.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
SourceQuery,
1919
set_catalog,
2020
)
21-
from sqlmesh.core.schema_diff import SchemaDiffer
2221

2322
if t.TYPE_CHECKING:
2423
from sqlmesh.core._typing import SchemaName, TableName
@@ -29,11 +28,11 @@
2928
class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
3029
DIALECT = "duckdb"
3130
SUPPORTS_TRANSACTIONS = False
32-
SCHEMA_DIFFER = SchemaDiffer(
33-
parameterized_type_defaults={
31+
SCHEMA_DIFFER_KWARGS = {
32+
"parameterized_type_defaults": {
3433
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
3534
},
36-
)
35+
}
3736
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
3837
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
3938
SUPPORTS_CREATE_DROP_CATALOG = True

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ def _default_precision_to_max(
259259
) -> t.Dict[str, exp.DataType]:
260260
# get default lengths for types that support "max" length
261261
types_with_max_default_param = {
262-
k: [self.SCHEMA_DIFFER.parameterized_type_defaults[k][0][0]]
263-
for k in self.SCHEMA_DIFFER.max_parameter_length
264-
if k in self.SCHEMA_DIFFER.parameterized_type_defaults
262+
k: [self.schema_differ.parameterized_type_defaults[k][0][0]]
263+
for k in self.schema_differ.max_parameter_length
264+
if k in self.schema_differ.parameterized_type_defaults
265265
}
266266

267267
# Redshift and MSSQL have a bug where CTAS statements have non-deterministic types. If a LIMIT
@@ -270,7 +270,7 @@ def _default_precision_to_max(
270270
# and supports "max" length, we convert it to "max" length to prevent inadvertent data truncation.
271271
for col_name, col_type in columns_to_types.items():
272272
if col_type.this in types_with_max_default_param and col_type.expressions:
273-
parameter = self.SCHEMA_DIFFER.get_type_parameters(col_type)
273+
parameter = self.schema_differ.get_type_parameters(col_type)
274274
type_default = types_with_max_default_param[col_type.this]
275275
if parameter == type_default:
276276
col_type.set("expressions", [exp.DataTypeParam(this=exp.var("max"))])

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
SourceQuery,
3131
set_catalog,
3232
)
33-
from sqlmesh.core.schema_diff import SchemaDiffer
3433
from sqlmesh.utils import get_source_columns_to_types
3534

3635
if t.TYPE_CHECKING:
@@ -54,8 +53,8 @@ class MSSQLEngineAdapter(
5453
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
5554
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
5655
SUPPORTS_REPLACE_TABLE = False
57-
SCHEMA_DIFFER = SchemaDiffer(
58-
parameterized_type_defaults={
56+
SCHEMA_DIFFER_KWARGS = {
57+
"parameterized_type_defaults": {
5958
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
6059
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
6160
exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)],
@@ -67,12 +66,12 @@ class MSSQLEngineAdapter(
6766
exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)],
6867
exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)],
6968
},
70-
max_parameter_length={
69+
"max_parameter_length": {
7170
exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB
7271
exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647,
7372
exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647,
7473
},
75-
)
74+
}
7675
VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}
7776

7877
@property

sqlmesh/core/engine_adapter/mysql.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
DataObjectType,
2020
set_catalog,
2121
)
22-
from sqlmesh.core.schema_diff import SchemaDiffer
2322

2423
if t.TYPE_CHECKING:
2524
from sqlmesh.core._typing import SchemaName, TableName
@@ -40,8 +39,8 @@ class MySQLEngineAdapter(
4039
MAX_COLUMN_COMMENT_LENGTH = 1024
4140
SUPPORTS_REPLACE_TABLE = False
4241
MAX_IDENTIFIER_LENGTH = 64
43-
SCHEMA_DIFFER = SchemaDiffer(
44-
parameterized_type_defaults={
42+
SCHEMA_DIFFER_KWARGS = {
43+
"parameterized_type_defaults": {
4544
exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)],
4645
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
4746
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
@@ -52,7 +51,7 @@ class MySQLEngineAdapter(
5251
exp.DataType.build("DATETIME", dialect=DIALECT).this: [(0,)],
5352
exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(0,)],
5453
},
55-
)
54+
}
5655

5756
def get_current_catalog(self) -> t.Optional[str]:
5857
"""Returns the catalog name of the current connection."""

0 commit comments

Comments (0)