
Commit 8ae058f

Remove time travel test for cloud engines, handle pyspark DFs in dbx
1 parent c9a6f83


2 files changed (+72 −52 lines)


sqlmesh/core/engine_adapter/databricks.py

Lines changed: 57 additions & 35 deletions
@@ -14,6 +14,7 @@
     SourceQuery,
 )
 from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
+from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
 from sqlmesh.core.node import IntervalUnit
 from sqlmesh.core.schema_diff import SchemaDiffer
 from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
@@ -380,38 +381,59 @@ def _record_execution_stats(
         except:
             return

-        history = self.cursor.fetchall_arrow()
-        if history.num_rows:
-            history_df = history.to_pandas()
-            write_df = history_df[history_df["operation"] == "WRITE"]
-            write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
-            if not write_df.empty:
-                metrics = write_df["operationMetrics"][0]
-                if metrics:
-                    rowcount = None
-                    rowcount_str = [
-                        metric[1] for metric in metrics if metric[0] == "numOutputRows"
-                    ]
-                    if rowcount_str:
-                        try:
-                            rowcount = int(rowcount_str[0])
-                        except (TypeError, ValueError):
-                            pass
-
-                    bytes_processed = None
-                    bytes_str = [
-                        metric[1] for metric in metrics if metric[0] == "numOutputBytes"
-                    ]
-                    if bytes_str:
-                        try:
-                            bytes_processed = int(bytes_str[0])
-                        except (TypeError, ValueError):
-                            pass
-
-                    if rowcount is not None or bytes_processed is not None:
-                        # if no rows were written, df contains 0 for bytes but no value for rows
-                        rowcount = (
-                            0 if rowcount is None and bytes_processed is not None else rowcount
-                        )
-
-                        QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
+        history = (
+            self.cursor.fetchdf()
+            if isinstance(self.cursor, SparkSessionCursor)
+            else self.cursor.fetchall_arrow()
+        )
+        if history is not None:
+            from pandas import DataFrame as PandasDataFrame
+            from pyspark.sql import DataFrame as PySparkDataFrame
+            from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
+
+            history_df = None
+            if isinstance(history, PandasDataFrame):
+                history_df = history
+            elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)):
+                history_df = history.toPandas()
+            else:
+                # arrow table
+                history_df = history.to_pandas()
+
+            if history_df is not None and not history_df.empty:
+                write_df = history_df[history_df["operation"] == "WRITE"]
+                write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
+                if not write_df.empty:
+                    metrics = write_df["operationMetrics"][0]
+                    if metrics:
+                        rowcount = None
+                        rowcount_str = [
+                            metric[1] for metric in metrics if metric[0] == "numOutputRows"
+                        ]
+                        if rowcount_str:
+                            try:
+                                rowcount = int(rowcount_str[0])
+                            except (TypeError, ValueError):
+                                pass
+
+                        bytes_processed = None
+                        bytes_str = [
+                            metric[1] for metric in metrics if metric[0] == "numOutputBytes"
+                        ]
+                        if bytes_str:
+                            try:
+                                bytes_processed = int(bytes_str[0])
+                            except (TypeError, ValueError):
+                                pass
+
+                        if rowcount is not None or bytes_processed is not None:
+                            # if no rows were written, df contains 0 for bytes but no value for rows
+                            rowcount = (
+                                0
+                                if rowcount is None and bytes_processed is not None
+                                else rowcount
+                            )
+
+                            QueryExecutionTracker.record_execution(
+                                sql, rowcount, bytes_processed
+                            )
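For reference, the new adapter code normalizes whatever the cursor hands back for the table history (a pandas DataFrame from SparkSessionCursor.fetchdf(), a PySpark or Spark Connect DataFrame, or an Arrow table) into pandas before reading the WRITE metrics. Below is a minimal standalone sketch of that normalization; it is not part of the commit, uses duck typing instead of the explicit isinstance checks the commit adds (so it runs without pyspark installed), and the helper name to_pandas_history is hypothetical.

# Illustrative sketch only (not part of the commit): coerce a Delta table
# history result to pandas regardless of which cursor produced it.
import typing as t

import pandas as pd


def to_pandas_history(history: t.Any) -> t.Optional[pd.DataFrame]:
    if history is None:
        return None
    if isinstance(history, pd.DataFrame):
        return history
    if hasattr(history, "toPandas"):
        # PySpark and Spark Connect DataFrames expose toPandas()
        return history.toPandas()
    if hasattr(history, "to_pandas"):
        # pyarrow.Table exposes to_pandas()
        return history.to_pandas()
    return None

Collapsing the three result types up front is what lets the metric-extraction logic below stay identical for the SQL connector and Spark session code paths.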

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 15 additions & 17 deletions
@@ -2462,23 +2462,21 @@ def capture_execution_stats(
         assert actual_execution_stats["full_model"].total_bytes_processed is not None

     # run that loads 0 rows in incremental model
-    actual_execution_stats = {}
-    with patch.object(
-        context.console, "update_snapshot_evaluation_progress", capture_execution_stats
-    ):
-        with time_machine.travel(date.today() + timedelta(days=1)):
-            context.run()
-
-    if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
-        assert actual_execution_stats["incremental_model"].total_rows_processed == 0
-        # snowflake doesn't track rows for CTAS
-        assert actual_execution_stats["full_model"].total_rows_processed == (
-            None if ctx.mark.startswith("snowflake") else 3
-        )
-
-    if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
-        assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
-        assert actual_execution_stats["full_model"].total_bytes_processed is not None
+    # - some cloud DBs error because time travel messes up token expiration
+    if not ctx.is_remote:
+        actual_execution_stats = {}
+        with patch.object(
+            context.console, "update_snapshot_evaluation_progress", capture_execution_stats
+        ):
+            with time_machine.travel(date.today() + timedelta(days=1)):
+                context.run()
+
+        if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
+            assert actual_execution_stats["incremental_model"].total_rows_processed == 0
+            # snowflake doesn't track rows for CTAS
+            assert actual_execution_stats["full_model"].total_rows_processed == (
+                None if ctx.mark.startswith("snowflake") else 3
+            )

     # make and validate unmodified dev environment
     no_change_plan: Plan = context.plan_builder(
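The removed branch ran the follow-up load under time_machine; per the new comment, some cloud engines fail there because the shifted clock interferes with auth token expiration. A self-contained illustration of that failure mode (not taken from the repository; the token value and expiry window are made up):

# Illustrative only: a client that compares token expiry against "now"
# sees a freshly issued token as expired once the clock is moved a day ahead.
import datetime as dt

import time_machine

# assume the cloud SDK just issued a token valid for one hour
token_expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(hours=1)

with time_machine.travel(dt.datetime.now(dt.timezone.utc) + dt.timedelta(days=1)):
    now = dt.datetime.now(dt.timezone.utc)
    # inside the travelled block the token already looks expired, so a real
    # remote connection would attempt a refresh or reject the session
    assert now > token_expires_at

Gating the block on ctx.is_remote keeps the time-travel assertions for engines that run locally while skipping them for remote warehouses.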
