Skip to content

Commit 2bbbe75

Browse files
committed
Make tracking fully instance-based by passing to engine adapter
1 parent db350b9 commit 2bbbe75

File tree

7 files changed

+42
-77
lines changed

7 files changed

+42
-77
lines changed

.circleci/continue_config.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ workflows:
297297
name: cloud_engine_<< matrix.engine >>
298298
context:
299299
- sqlmesh_cloud_database_integration
300-
requires:
301-
- engine_tests_docker
300+
# requires:
301+
# - engine_tests_docker
302302
matrix:
303303
parameters:
304304
engine:
@@ -310,10 +310,10 @@ workflows:
310310
- athena
311311
- fabric
312312
- gcp-postgres
313-
filters:
314-
branches:
315-
only:
316-
- main
313+
# filters:
314+
# branches:
315+
# only:
316+
# - main
317317
- ui_style
318318
- ui_test
319319
- vscode_test

sqlmesh/core/engine_adapter/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(
135135
shared_connection: bool = False,
136136
correlation_id: t.Optional[CorrelationId] = None,
137137
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None,
138+
query_execution_tracker: t.Optional[QueryExecutionTracker] = None,
138139
**kwargs: t.Any,
139140
):
140141
self.dialect = dialect.lower() or self.DIALECT
@@ -158,6 +159,7 @@ def __init__(
158159
self._multithreaded = multithreaded
159160
self.correlation_id = correlation_id
160161
self._schema_differ_overrides = schema_differ_overrides
162+
self._query_execution_tracker = query_execution_tracker
161163

162164
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
163165
extra_kwargs = {
@@ -2448,15 +2450,17 @@ def _log_sql(
24482450
def _record_execution_stats(
24492451
self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
24502452
) -> None:
2451-
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
2453+
if self._query_execution_tracker:
2454+
self._query_execution_tracker.record_execution(sql, rowcount, bytes_processed)
24522455

24532456
def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
24542457
self.cursor.execute(sql, **kwargs)
24552458

24562459
if (
24572460
self.SUPPORTS_QUERY_EXECUTION_TRACKING
24582461
and track_rows_processed
2459-
and QueryExecutionTracker.is_tracking()
2462+
and self._query_execution_tracker
2463+
and self._query_execution_tracker.is_tracking()
24602464
):
24612465
if (
24622466
rowcount := getattr(self.cursor, "rowcount", None)

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
)
2424
from sqlmesh.core.node import IntervalUnit
2525
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
26-
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
2726
from sqlmesh.utils import optional_import, get_source_columns_to_types
2827
from sqlmesh.utils.date import to_datetime
2928
from sqlmesh.utils.errors import SQLMeshError
@@ -1097,7 +1096,11 @@ def _execute(
10971096
self.cursor._set_rowcount(query_results)
10981097
self.cursor._set_description(query_results.schema)
10991098

1100-
if track_rows_processed and QueryExecutionTracker.is_tracking():
1099+
if (
1100+
track_rows_processed
1101+
and self._query_execution_tracker
1102+
and self._query_execution_tracker.is_tracking()
1103+
):
11011104
num_rows = None
11021105
if query_job.statement_type == "CREATE_TABLE_AS_SELECT":
11031106
# since table was just created, number rows in table == number rows processed
@@ -1106,7 +1109,9 @@ def _execute(
11061109
elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]:
11071110
num_rows = query_job.num_dml_affected_rows
11081111

1109-
QueryExecutionTracker.record_execution(sql, num_rows, query_job.total_bytes_processed)
1112+
self._query_execution_tracker.record_execution(
1113+
sql, num_rows, query_job.total_bytes_processed
1114+
)
11101115

11111116
def _get_data_objects(
11121117
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import contextlib
44
import logging
5-
import re
65
import typing as t
76

87
from sqlglot import exp
@@ -24,7 +23,6 @@
2423
SourceQuery,
2524
set_catalog,
2625
)
27-
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
2826
from sqlmesh.utils import optional_import, get_source_columns_to_types
2927
from sqlmesh.utils.errors import SQLMeshError
3028
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
@@ -189,7 +187,7 @@ def _create_table(
189187
table_description=table_description,
190188
column_descriptions=column_descriptions,
191189
table_kind=table_kind,
192-
track_rows_processed=track_rows_processed,
190+
track_rows_processed=False,
193191
**kwargs,
194192
)
195193

@@ -667,41 +665,3 @@ def close(self) -> t.Any:
667665
self._connection_pool.set_attribute(self.SNOWPARK, None)
668666

669667
return super().close()
670-
671-
def _record_execution_stats(
672-
self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
673-
) -> None:
674-
"""Snowflake does not report row counts for CTAS like other DML operations.
675-
676-
They neither report the sentinel value -1 nor do they report 0 rows. Instead, they report a rowcount
677-
of 1 and return a single data row containing one of the strings:
678-
- "Table <table_name> successfully created."
679-
- "<table_name> already exists, statement succeeded."
680-
681-
We do not want to record the incorrect row count of 1, so we check whether that row contains the table
682-
successfully created string. If so, we return early and do not record the row count.
683-
684-
Ref: https://github.com/snowflakedb/snowflake-connector-python/issues/645
685-
"""
686-
if rowcount == 1:
687-
results = self.cursor.fetchone()
688-
if results:
689-
try:
690-
results_str = str(results[0])
691-
except (TypeError, ValueError, IndexError):
692-
return
693-
694-
# Snowflake identifiers may be:
695-
# - An unquoted contiguous set of [a-zA-Z0-9_$] characters
696-
# - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. Example: " my ""table"" name "
697-
is_created = re.match(
698-
r'Table [a-zA-Z0-9_$ "]*? successfully created\.', results_str
699-
)
700-
is_already_exists = re.match(
701-
r'[a-zA-Z0-9_$ "]*? already exists, statement succeeded\.',
702-
results_str,
703-
)
704-
if is_created or is_already_exists:
705-
return
706-
707-
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

sqlmesh/core/snapshot/evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,18 @@ def __init__(
130130
self.adapters = (
131131
adapters if isinstance(adapters, t.Dict) else {selected_gateway or "": adapters}
132132
)
133+
self.execution_tracker = QueryExecutionTracker()
134+
self.adapters = {
135+
gateway: adapter.with_settings(query_execution_tracker=self.execution_tracker)
136+
for gateway, adapter in self.adapters.items()
137+
}
133138
self.adapter = (
134139
next(iter(self.adapters.values()))
135140
if not selected_gateway
136141
else self.adapters[selected_gateway]
137142
)
138143
self.selected_gateway = selected_gateway
139144
self.ddl_concurrent_tasks = ddl_concurrent_tasks
140-
self.execution_tracker = QueryExecutionTracker()
141145

142146
def evaluate(
143147
self,

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import typing as t
44
from contextlib import contextmanager
5-
from threading import local, Lock
5+
from threading import local
66
from dataclasses import dataclass, field
77
from sqlmesh.core.snapshot import SnapshotIdBatch
88

@@ -56,19 +56,17 @@ def get_execution_stats(self) -> QueryExecutionStats:
5656
class QueryExecutionTracker:
5757
"""Thread-local context manager for snapshot execution statistics, such as rows processed."""
5858

59-
_thread_local = local()
60-
_contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}
61-
_contexts_lock = Lock()
59+
def __init__(self) -> None:
60+
self._thread_local = local()
61+
self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}
6262

6363
def get_execution_context(
6464
self, snapshot_id_batch: SnapshotIdBatch
6565
) -> t.Optional[QueryExecutionContext]:
66-
with self._contexts_lock:
67-
return self._contexts.get(snapshot_id_batch)
66+
return self._contexts.get(snapshot_id_batch)
6867

69-
@classmethod
70-
def is_tracking(cls) -> bool:
71-
return getattr(cls._thread_local, "context", None) is not None
68+
def is_tracking(self) -> bool:
69+
return getattr(self._thread_local, "context", None) is not None
7270

7371
@contextmanager
7472
def track_execution(
@@ -77,26 +75,23 @@ def track_execution(
7775
"""Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
7876
context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
7977
self._thread_local.context = context
80-
with self._contexts_lock:
81-
self._contexts[snapshot_id_batch] = context
78+
self._contexts[snapshot_id_batch] = context
8279

8380
try:
8481
yield context
8582
finally:
8683
self._thread_local.context = None
8784

88-
@classmethod
8985
def record_execution(
90-
cls, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
86+
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
9187
) -> None:
92-
context = getattr(cls._thread_local, "context", None)
88+
context = getattr(self._thread_local, "context", None)
9389
if context is not None:
9490
context.add_execution(sql, row_count, bytes_processed)
9591

9692
def get_execution_stats(
9793
self, snapshot_id_batch: SnapshotIdBatch
9894
) -> t.Optional[QueryExecutionStats]:
99-
with self._contexts_lock:
100-
context = self._contexts.get(snapshot_id_batch)
101-
self._contexts.pop(snapshot_id_batch, None)
95+
context = self._contexts.get(snapshot_id_batch)
96+
self._contexts.pop(snapshot_id_batch, None)
10297
return context.get_execution_stats() if context else None

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,16 +327,13 @@ def test_rows_tracker(
327327
SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0)
328328
):
329329
# Snowflake doesn't report row counts for CTAS, so this should not be tracked
330-
engine_adapter.execute(
331-
"CREATE TABLE a (id int) AS SELECT 1 as id", track_rows_processed=True
332-
)
333-
engine_adapter.execute("INSERT INTO a VALUES (2), (3)", track_rows_processed=True)
334-
engine_adapter.execute("INSERT INTO a VALUES (4)", track_rows_processed=True)
330+
engine_adapter._create_table("a", exp.select("1 as id"))
335331

336-
assert add_execution_spy.call_count == 2
332+
assert add_execution_spy.call_count == 0
337333

338334
stats = tracker.get_execution_stats(
339335
SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0)
340336
)
341337
assert stats is not None
342-
assert stats.total_rows_processed == 3
338+
assert stats.total_rows_processed is None
339+
assert stats.total_bytes_processed is None

0 commit comments

Comments (0)