
Commit b7754b5

Remove seed tracking, have snapshot evaluator own tracker instance
1 parent 29f0c27

6 files changed: +46, -57 lines

sqlmesh/core/console.py

Lines changed: 1 addition & 1 deletion
@@ -4274,7 +4274,7 @@ def _calculate_annotation_str_len(
 def _format_bytes(num_bytes: t.Optional[int]) -> str:
     if num_bytes and num_bytes > 0:
         if num_bytes < 1024:
-            return f"{num_bytes} Bytes"
+            return f"{num_bytes} bytes"

         num_bytes_float = float(num_bytes) / 1024.0
         for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]:
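
Note: the only functional change in this hunk is the casing of the sub-KiB label. For reference, a standalone formatter in the same spirit might look like the sketch below; the hunk only shows the head of _format_bytes, so the loop body and the fallbacks here are assumptions, not sqlmesh's actual code.

import typing as t


def format_bytes(num_bytes: t.Optional[int]) -> str:
    # Hypothetical completion of the pattern shown above: step through
    # successively larger binary units once the value exceeds 1024.
    if num_bytes and num_bytes > 0:
        if num_bytes < 1024:
            return f"{num_bytes} bytes"

        num_bytes_float = float(num_bytes) / 1024.0
        for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]:
            if num_bytes_float < 1024.0:
                return f"{num_bytes_float:.1f} {unit}"
            num_bytes_float /= 1024.0
        return f"{num_bytes_float:.1f} EiB"  # assumed fallback for very large values
    return "0 bytes"  # assumed handling for None/zero/negative input


print(format_bytes(2_500_000))  # 2.4 MiB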

sqlmesh/core/execution_tracker.py

Lines changed: 15 additions & 17 deletions
@@ -3,7 +3,7 @@
 import time
 import typing as t
 from contextlib import contextmanager
-from threading import local
+from threading import local, Lock
 from dataclasses import dataclass, field


@@ -66,34 +66,32 @@ class QueryExecutionTracker:

     _thread_local = local()
     _contexts: t.Dict[str, QueryExecutionContext] = {}
+    _contexts_lock = Lock()

-    @classmethod
-    def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
-        return cls._contexts.get(snapshot_id_batch)
+    def get_execution_context(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
+        with self._contexts_lock:
+            return self._contexts.get(snapshot_id_batch)

     @classmethod
     def is_tracking(cls) -> bool:
         return getattr(cls._thread_local, "context", None) is not None

-    @classmethod
     @contextmanager
     def track_execution(
-        cls, snapshot_id_batch: str, condition: bool = True
+        self, snapshot_id_batch: str
     ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
         """
         Context manager for tracking snapshot execution statistics.
         """
-        if not condition:
-            yield None
-            return
-
         context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch)
-        cls._thread_local.context = context
-        cls._contexts[snapshot_id_batch] = context
+        self._thread_local.context = context
+        with self._contexts_lock:
+            self._contexts[snapshot_id_batch] = context
+
         try:
             yield context
         finally:
-            cls._thread_local.context = None
+            self._thread_local.context = None

     @classmethod
     def record_execution(
@@ -103,8 +101,8 @@ def record_execution(
         if context is not None:
             context.add_execution(sql, row_count, bytes_processed)

-    @classmethod
-    def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]:
-        context = cls.get_execution_context(snapshot_id_batch)
-        cls._contexts.pop(snapshot_id_batch, None)
+    def get_execution_stats(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]:
+        with self._contexts_lock:
+            context = self._contexts.get(snapshot_id_batch)
+            self._contexts.pop(snapshot_id_batch, None)
         return context.get_execution_stats() if context else None
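
The shape of the refactor above: tracking state stays per-thread via threading.local, while the shared contexts dict is now guarded by a Lock, and the main entry points become instance methods instead of classmethods. Below is a minimal standalone sketch of that pattern; Tracker and QueryContext are simplified stand-ins for illustration, not the sqlmesh classes.

import threading
import typing as t
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class QueryContext:
    # Simplified stand-in for QueryExecutionContext: just totals row counts.
    batch_id: str
    rows: int = 0

    def add_execution(self, row_count: t.Optional[int]) -> None:
        self.rows += row_count or 0


class Tracker:
    # Instance-owned tracker: a per-thread "current context" plus a
    # lock-guarded dict so concurrent batches can register and read safely.
    def __init__(self) -> None:
        self._thread_local = threading.local()
        self._contexts: t.Dict[str, QueryContext] = {}
        self._contexts_lock = threading.Lock()

    @contextmanager
    def track_execution(self, batch_id: str) -> t.Iterator[QueryContext]:
        context = QueryContext(batch_id)
        self._thread_local.context = context
        with self._contexts_lock:
            self._contexts[batch_id] = context
        try:
            yield context
        finally:
            self._thread_local.context = None

    def record_execution(self, row_count: t.Optional[int]) -> None:
        # No-op outside of track_execution, mirroring record_execution above.
        context = getattr(self._thread_local, "context", None)
        if context is not None:
            context.add_execution(row_count)

    def get_stats(self, batch_id: str) -> t.Optional[int]:
        with self._contexts_lock:
            context = self._contexts.pop(batch_id, None)
        return context.rows if context else None


tracker = Tracker()
with tracker.track_execution("snapshot_1_0"):
    tracker.record_execution(7)
print(tracker.get_stats("snapshot_1_0"))  # 7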

sqlmesh/core/scheduler.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 from sqlmesh.core import constants as c
 from sqlmesh.core.console import Console, get_console
 from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
-from sqlmesh.core.execution_tracker import QueryExecutionTracker
 from sqlmesh.core.macros import RuntimeStage
 from sqlmesh.core.model.definition import AuditResult
 from sqlmesh.core.node import IntervalUnit
@@ -463,7 +462,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
             num_audits = len(audit_results)
             num_audits_failed = sum(1 for result in audit_results if result.count)

-            execution_stats = QueryExecutionTracker.get_execution_stats(
+            execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
                 f"{snapshot.snapshot_id}_{batch_idx}"
             )

sqlmesh/core/snapshot/evaluator.py

Lines changed: 21 additions & 28 deletions
@@ -130,6 +130,7 @@ def __init__(
         )
         self.selected_gateway = selected_gateway
         self.ddl_concurrent_tasks = ddl_concurrent_tasks
+        self.execution_tracker = QueryExecutionTracker()

     def evaluate(
         self,
@@ -158,9 +159,7 @@ def evaluate(
         Returns:
             The WAP ID of this evaluation if supported, None otherwise.
         """
-        with QueryExecutionTracker.track_execution(
-            f"{snapshot.snapshot_id}_{batch_index}", condition=not snapshot.is_seed
-        ):
+        with self.execution_tracker.track_execution(f"{snapshot.snapshot_id}_{batch_index}"):
            result = self._evaluate_snapshot(
                snapshot,
                start,
@@ -204,19 +203,16 @@ def evaluate_and_fetch(
         Returns:
             The result of the evaluation as a dataframe.
         """
-        with QueryExecutionTracker.track_execution(
-            f"{snapshot.snapshot_id}_0", condition=not snapshot.is_seed
-        ):
-            result = self._evaluate_snapshot(
-                snapshot,
-                start,
-                end,
-                execution_time,
-                snapshots,
-                limit=limit,
-                deployability_index=deployability_index,
-                **kwargs,
-            )
+        result = self._evaluate_snapshot(
+            snapshot,
+            start,
+            end,
+            execution_time,
+            snapshots,
+            limit=limit,
+            deployability_index=deployability_index,
+            **kwargs,
+        )
         if result is None or isinstance(result, str):
             raise SQLMeshError(
                 f"Unexpected result {result} when evaluating snapshot {snapshot.snapshot_id}."
@@ -903,18 +899,15 @@ def _create_snapshot(
                )
                continue

-            with QueryExecutionTracker.track_execution(
-                f"{snapshot.snapshot_id}_0", condition=snapshot.is_seed
-            ):
-                self._execute_create(
-                    snapshot=snapshot,
-                    table_name=snapshot.table_name(is_deployable=is_table_deployable),
-                    is_table_deployable=is_table_deployable,
-                    deployability_index=deployability_index,
-                    create_render_kwargs=create_render_kwargs,
-                    rendered_physical_properties=rendered_physical_properties,
-                    dry_run=dry_run,
-                )
+            self._execute_create(
+                snapshot=snapshot,
+                table_name=snapshot.table_name(is_deployable=is_table_deployable),
+                is_table_deployable=is_table_deployable,
+                deployability_index=deployability_index,
+                create_render_kwargs=create_render_kwargs,
+                rendered_physical_properties=rendered_physical_properties,
+                dry_run=dry_run,
+            )

         if on_complete is not None:
             on_complete(snapshot)
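
Taken together with the scheduler.py hunk above, ownership now flows through the evaluator: SnapshotEvaluator creates its own QueryExecutionTracker, evaluate wraps every batch (the condition=not snapshot.is_seed flag is gone), and the scheduler reads stats via self.snapshot_evaluator.execution_tracker. A toy sketch of that flow, reusing the simplified Tracker from the earlier sketch; the Evaluator and Scheduler classes here are illustrative stand-ins, not the sqlmesh ones.

class Evaluator:
    # Owns its tracker, mirroring self.execution_tracker = QueryExecutionTracker().
    def __init__(self) -> None:
        self.execution_tracker = Tracker()

    def evaluate(self, snapshot_id: str, batch_index: int) -> None:
        # Every batch is tracked; there is no longer a seed-specific condition.
        with self.execution_tracker.track_execution(f"{snapshot_id}_{batch_index}"):
            self.execution_tracker.record_execution(3)  # stand-in for real query work


class Scheduler:
    def __init__(self, snapshot_evaluator: Evaluator) -> None:
        self.snapshot_evaluator = snapshot_evaluator

    def evaluate_node(self, snapshot_id: str, batch_idx: int) -> None:
        self.snapshot_evaluator.evaluate(snapshot_id, batch_idx)
        # The scheduler no longer imports a global tracker; it asks the evaluator's.
        stats = self.snapshot_evaluator.execution_tracker.get_stats(
            f"{snapshot_id}_{batch_idx}"
        )
        print(snapshot_id, batch_idx, stats)


Scheduler(Evaluator()).evaluate_node("full_model", 0)  # prints: full_model 0 3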

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 0 additions & 3 deletions
@@ -2403,13 +2403,10 @@ def capture_row_counts(
     assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3

     if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
-        assert len(actual_execution_stats) == 3
-        assert actual_execution_stats["seed_model"].total_rows_processed == 7
         assert actual_execution_stats["incremental_model"].total_rows_processed == 7
         assert actual_execution_stats["full_model"].total_rows_processed == 3

         if ctx.mark.startswith("bigquery"):
-            assert actual_execution_stats["seed_model"].total_bytes_processed
             assert actual_execution_stats["incremental_model"].total_bytes_processed
             assert actual_execution_stats["full_model"].total_bytes_processed

tests/core/test_execution_tracker.py

Lines changed: 8 additions & 6 deletions
@@ -7,15 +7,17 @@

 def test_execution_tracker_thread_isolation() -> None:
     def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
-        with QueryExecutionTracker.track_execution(id) as ctx:
-            assert QueryExecutionTracker.is_tracking()
+        with execution_tracker.track_execution(id) as ctx:
+            assert execution_tracker.is_tracking()

             for count in row_counts:
-                QueryExecutionTracker.record_execution("SELECT 1", count, None)
+                execution_tracker.record_execution("SELECT 1", count, None)

             assert ctx is not None
             return ctx.get_execution_stats()

+    execution_tracker = QueryExecutionTracker()
+
     with ThreadPoolExecutor() as executor:
         futures = [
             executor.submit(worker, "batch_A", [10, 5]),
@@ -24,9 +26,9 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
     results = [f.result() for f in futures]

     # Main thread has no active tracking context
-    assert not QueryExecutionTracker.is_tracking()
-    QueryExecutionTracker.record_execution("q", 10, None)
-    assert QueryExecutionTracker.get_execution_stats("q") is None
+    assert not execution_tracker.is_tracking()
+    execution_tracker.record_execution("q", 10, None)
+    assert execution_tracker.get_execution_stats("q") is None

     # Order of results is not deterministic, so look up by id
     by_batch = {s.snapshot_batch_id: s for s in results}
