Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions sqlmesh/core/engine_adapter/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.connection_pool import ConnectionPool
from sqlmesh.core.schema_diff import TableAlterOperation
from sqlmesh.utils import random_id


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -154,6 +156,113 @@ def set_current_catalog(self, catalog_name: str) -> None:
f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}"
)

def alter_table(
self, alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]]
) -> None:
"""
Applies alter expressions to a table. Fabric has limited support for ALTER TABLE,
so this method implements a workaround for column type changes.
This method is self-contained and sets its own catalog context.
"""
if not alter_expressions:
return

# Get the target table from the first expression to determine the correct catalog.
first_op = alter_expressions[0]
expression = first_op.expression if isinstance(first_op, TableAlterOperation) else first_op
if not isinstance(expression, exp.Alter) or not expression.this.catalog:
# Fallback for unexpected scenarios
logger.warning(
"Could not determine catalog from alter expression, executing with current context."
)
super().alter_table(alter_expressions)
return

target_catalog = expression.this.catalog
self.set_current_catalog(target_catalog)

with self.transaction():
for op in alter_expressions:
expression = op.expression if isinstance(op, TableAlterOperation) else op

if not isinstance(expression, exp.Alter):
self.execute(expression)
continue

for action in expression.actions:
table_name = expression.this

table_name_without_catalog = table_name.copy()
table_name_without_catalog.set("catalog", None)

is_type_change = isinstance(action, exp.AlterColumn) and action.args.get(
"dtype"
)

if is_type_change:
column_to_alter = action.this
new_type = action.args["dtype"]
temp_column_name_str = f"{column_to_alter.name}__{random_id(short=True)}"
temp_column_name = exp.to_identifier(temp_column_name_str)

logger.info(
"Applying workaround for column '%s' on table '%s' to change type to '%s'.",
column_to_alter.sql(),
table_name.sql(),
new_type.sql(),
)

# Step 1: Add a temporary column.
add_column_expr = exp.Alter(
this=table_name_without_catalog.copy(),
kind="TABLE",
actions=[
exp.ColumnDef(this=temp_column_name.copy(), kind=new_type.copy())
],
)
add_sql = self._to_sql(add_column_expr)
self.execute(add_sql)

# Step 2: Copy and cast data.
update_sql = self._to_sql(
exp.Update(
this=table_name_without_catalog.copy(),
expressions=[
exp.EQ(
this=temp_column_name.copy(),
expression=exp.Cast(
this=column_to_alter.copy(), to=new_type.copy()
),
)
],
)
)
self.execute(update_sql)

# Step 3: Drop the original column.
drop_sql = self._to_sql(
exp.Alter(
this=table_name_without_catalog.copy(),
kind="TABLE",
actions=[exp.Drop(this=column_to_alter.copy(), kind="COLUMN")],
)
)
self.execute(drop_sql)

# Step 4: Rename the temporary column.
old_name_qualified = f"{table_name_without_catalog.sql(dialect=self.dialect)}.{temp_column_name.sql(dialect=self.dialect)}"
new_name_unquoted = column_to_alter.sql(
dialect=self.dialect, identify=False
)
rename_sql = f"EXEC sp_rename '{old_name_qualified}', '{new_name_unquoted}', 'COLUMN'"
self.execute(rename_sql)
else:
# For other alterations, execute directly.
direct_alter_expr = exp.Alter(
this=table_name_without_catalog.copy(), kind="TABLE", actions=[action]
)
self.execute(direct_alter_expr)


class FabricHttpClient:
def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str):
Expand Down
55 changes: 55 additions & 0 deletions tests/core/engine_adapter/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,58 @@ def test_replace_query(adapter: FabricEngineAdapter, mocker: MockerFixture):
"TRUNCATE TABLE [test_table];",
"INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];",
]


def test_alter_table_column_type_workaround(adapter: FabricEngineAdapter, mocker: MockerFixture):
"""
Tests the alter_table method's workaround for changing a column's data type.
"""
# Mock set_current_catalog to avoid connection pool side effects
set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog")
# Mock random_id to have a predictable temporary column name
mocker.patch("sqlmesh.core.engine_adapter.fabric.random_id", return_value="abcdef")

alter_expression = exp.Alter(
this=exp.to_table("my_db.my_schema.my_table"),
actions=[
exp.AlterColumn(
this=exp.to_column("col_a"),
dtype=exp.DataType.build("BIGINT"),
)
],
)

adapter.alter_table([alter_expression])

set_catalog_mock.assert_called_once_with("my_db")

expected_calls = [
"ALTER TABLE [my_schema].[my_table] ADD [col_a__abcdef] BIGINT;",
"UPDATE [my_schema].[my_table] SET [col_a__abcdef] = CAST([col_a] AS BIGINT);",
"ALTER TABLE [my_schema].[my_table] DROP COLUMN [col_a];",
"EXEC sp_rename 'my_schema.my_table.col_a__abcdef', 'col_a', 'COLUMN'",
]

assert to_sql_calls(adapter) == expected_calls


def test_alter_table_direct_alteration(adapter: FabricEngineAdapter, mocker: MockerFixture):
"""
Tests the alter_table method for direct alterations like adding a column.
"""
set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog")

alter_expression = exp.Alter(
this=exp.to_table("my_db.my_schema.my_table"),
actions=[exp.ColumnDef(this=exp.to_column("new_col"), kind=exp.DataType.build("INT"))],
)

adapter.alter_table([alter_expression])

set_catalog_mock.assert_called_once_with("my_db")

expected_calls = [
"ALTER TABLE [my_schema].[my_table] ADD [new_col] INT;",
]

assert to_sql_calls(adapter) == expected_calls