Skip to content

Commit 1e405ce

Browse files
committed
Move all tracking into snapshot evaluator, remove seed tracker class
1 parent a78e5d1 commit 1e405ce

File tree

6 files changed: +141 additions, -225 deletions

sqlmesh/core/engine_adapter/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from sqlmesh.core.model.kind import TimeColumn
4242
from sqlmesh.core.schema_diff import SchemaDiffer
43-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
43+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
4444
from sqlmesh.utils import (
4545
CorrelationId,
4646
columns_to_types_all_known,
@@ -2404,7 +2404,11 @@ def _log_sql(
24042404
def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None:
24052405
self.cursor.execute(sql, **kwargs)
24062406

2407-
if track_row_count and self.SUPPORTS_QUERY_EXECUTION_TRACKING:
2407+
if (
2408+
self.SUPPORTS_QUERY_EXECUTION_TRACKING
2409+
and track_row_count
2410+
and QueryExecutionTracker.is_tracking()
2411+
):
24082412
rowcount_raw = getattr(self.cursor, "rowcount", None)
24092413
rowcount = None
24102414
if rowcount_raw is not None:
@@ -2413,7 +2417,7 @@ def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) ->
24132417
except (TypeError, ValueError):
24142418
pass
24152419

2416-
track_execution_record(sql, rowcount)
2420+
QueryExecutionTracker.record_execution(sql, rowcount)
24172421

24182422
@contextlib.contextmanager
24192423
def temp_table(

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
SourceQuery,
2121
set_catalog,
2222
)
23-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
23+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
2424
from sqlmesh.core.node import IntervalUnit
2525
from sqlmesh.core.schema_diff import SchemaDiffer
2626
from sqlmesh.utils import optional_import, get_source_columns_to_types
@@ -1104,7 +1104,7 @@ def _execute(
11041104
elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]:
11051105
num_rows = query_job.num_dml_affected_rows
11061106

1107-
track_execution_record(sql, num_rows)
1107+
QueryExecutionTracker.record_execution(sql, num_rows)
11081108

11091109
def _get_data_objects(
11101110
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None

sqlmesh/core/execution_tracker.py

Lines changed: 19 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class QueryExecutionContext:
2727
queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list)
2828

2929
def add_execution(self, sql: str, row_count: t.Optional[int]) -> None:
30-
"""Record a single query execution."""
3130
if row_count is not None and row_count >= 0:
3231
self.total_rows_processed += row_count
3332
self.query_count += 1
@@ -46,28 +45,36 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]:
4645

4746
class QueryExecutionTracker:
4847
"""
49-
Thread-local context manager for snapshot evaluation execution statistics, such as
48+
Thread-local context manager for snapshot execution statistics, such as
5049
rows processed.
5150
"""
5251

5352
_thread_local = local()
53+
_contexts: t.Dict[str, QueryExecutionContext] = {}
5454

5555
@classmethod
56-
def get_execution_context(cls) -> t.Optional[QueryExecutionContext]:
57-
return getattr(cls._thread_local, "context", None)
56+
def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
57+
return cls._contexts.get(snapshot_id_batch)
5858

5959
@classmethod
6060
def is_tracking(cls) -> bool:
61-
return cls.get_execution_context() is not None
61+
return getattr(cls._thread_local, "context", None) is not None
6262

6363
@classmethod
6464
@contextmanager
65-
def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionContext]:
65+
def track_execution(
66+
cls, snapshot_id_batch: str, condition: bool = True
67+
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
6668
"""
67-
Context manager for tracking snapshot evaluation execution statistics.
69+
Context manager for tracking snapshot execution statistics.
6870
"""
69-
context = QueryExecutionContext(id=snapshot_name_batch)
71+
if not condition:
72+
yield None
73+
return
74+
75+
context = QueryExecutionContext(id=snapshot_id_batch)
7076
cls._thread_local.context = context
77+
cls._contexts[snapshot_id_batch] = context
7178
try:
7279
yield context
7380
finally:
@@ -76,67 +83,12 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7683

7784
@classmethod
7885
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
79-
context = cls.get_execution_context()
86+
context = getattr(cls._thread_local, "context", None)
8087
if context is not None:
8188
context.add_execution(sql, row_count)
8289

8390
@classmethod
84-
def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
85-
context = cls.get_execution_context()
91+
def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[t.Dict[str, t.Any]]:
92+
context = cls.get_execution_context(snapshot_id_batch)
93+
cls._contexts.pop(snapshot_id_batch, None)
8694
return context.get_execution_stats() if context else None
87-
88-
89-
class SeedExecutionTracker:
90-
_seed_contexts: t.Dict[str, QueryExecutionContext] = {}
91-
_thread_local = local()
92-
93-
@classmethod
94-
@contextmanager
95-
def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]:
96-
"""
97-
Context manager for tracking seed creation execution statistics.
98-
"""
99-
context = QueryExecutionContext(id=model_name)
100-
cls._seed_contexts[model_name] = context
101-
cls._thread_local.seed_id = model_name
102-
103-
try:
104-
yield context
105-
finally:
106-
if hasattr(cls._thread_local, "seed_id"):
107-
delattr(cls._thread_local, "seed_id")
108-
109-
@classmethod
110-
def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]:
111-
context = cls._seed_contexts.pop(model_name, None)
112-
return context.get_execution_stats() if context else None
113-
114-
@classmethod
115-
def clear_all_seed_stats(cls) -> None:
116-
"""Clear all remaining seed stats. Used for cleanup after evaluation completes."""
117-
cls._seed_contexts.clear()
118-
119-
@classmethod
120-
def is_tracking(cls) -> bool:
121-
return hasattr(cls._thread_local, "seed_id")
122-
123-
@classmethod
124-
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
125-
seed_id = getattr(cls._thread_local, "seed_id", None)
126-
if seed_id:
127-
context = cls._seed_contexts.get(seed_id)
128-
if context is not None:
129-
context.add_execution(sql, row_count)
130-
131-
132-
def record_execution(sql: str, row_count: t.Optional[int]) -> None:
133-
"""
134-
Record execution statistics for a single SQL statement.
135-
136-
Automatically infers which tracker is active based on the current thread.
137-
"""
138-
if SeedExecutionTracker.is_tracking():
139-
SeedExecutionTracker.record_execution(sql, row_count)
140-
return
141-
if QueryExecutionTracker.is_tracking():
142-
QueryExecutionTracker.record_execution(sql, row_count)

sqlmesh/core/scheduler.py

Lines changed: 52 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlmesh.core import constants as c
88
from sqlmesh.core.console import Console, get_console
99
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
10-
from sqlmesh.core.execution_tracker import QueryExecutionTracker, SeedExecutionTracker
10+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
1111
from sqlmesh.core.macros import RuntimeStage
1212
from sqlmesh.core.model.definition import AuditResult
1313
from sqlmesh.core.node import IntervalUnit
@@ -427,69 +427,59 @@ def evaluate_node(node: SchedulingUnit) -> None:
427427
return
428428
snapshot = self.snapshots_by_name[snapshot_name]
429429

430-
with QueryExecutionTracker.track_execution(
431-
f"{snapshot.name}_{batch_idx}"
432-
) as execution_context:
433-
self.console.start_snapshot_evaluation_progress(snapshot)
434-
435-
execution_start_ts = now_timestamp()
436-
evaluation_duration_ms: t.Optional[int] = None
437-
438-
audit_results: t.List[AuditResult] = []
439-
try:
440-
assert execution_time # mypy
441-
assert deployability_index # mypy
442-
443-
if audit_only:
444-
audit_results = self._audit_snapshot(
445-
snapshot=snapshot,
446-
environment_naming_info=environment_naming_info,
447-
deployability_index=deployability_index,
448-
snapshots=self.snapshots_by_name,
449-
start=start,
450-
end=end,
451-
execution_time=execution_time,
452-
)
453-
else:
454-
audit_results = self.evaluate(
455-
snapshot=snapshot,
456-
environment_naming_info=environment_naming_info,
457-
start=start,
458-
end=end,
459-
execution_time=execution_time,
460-
deployability_index=deployability_index,
461-
batch_index=batch_idx,
462-
)
463-
464-
evaluation_duration_ms = now_timestamp() - execution_start_ts
465-
finally:
466-
num_audits = len(audit_results)
467-
num_audits_failed = sum(1 for result in audit_results if result.count)
468-
469-
rows_processed = None
470-
if snapshot.is_seed:
471-
# seed stats are tracked in SeedStrategy.create by model name, not snapshot name
472-
seed_stats = SeedExecutionTracker.get_and_clear_seed_stats(
473-
snapshot.model.name
474-
)
475-
rows_processed = (
476-
seed_stats.get("total_rows_processed") if seed_stats else None
477-
)
478-
else:
479-
rows_processed = (
480-
execution_context.total_rows_processed if execution_context else None
481-
)
482-
483-
self.console.update_snapshot_evaluation_progress(
484-
snapshot,
485-
batched_intervals[snapshot][batch_idx],
486-
batch_idx,
487-
evaluation_duration_ms,
488-
num_audits - num_audits_failed,
489-
num_audits_failed,
490-
rows_processed=rows_processed,
430+
self.console.start_snapshot_evaluation_progress(snapshot)
431+
432+
execution_start_ts = now_timestamp()
433+
evaluation_duration_ms: t.Optional[int] = None
434+
435+
audit_results: t.List[AuditResult] = []
436+
try:
437+
assert execution_time # mypy
438+
assert deployability_index # mypy
439+
440+
if audit_only:
441+
audit_results = self._audit_snapshot(
442+
snapshot=snapshot,
443+
environment_naming_info=environment_naming_info,
444+
deployability_index=deployability_index,
445+
snapshots=self.snapshots_by_name,
446+
start=start,
447+
end=end,
448+
execution_time=execution_time,
449+
)
450+
else:
451+
audit_results = self.evaluate(
452+
snapshot=snapshot,
453+
environment_naming_info=environment_naming_info,
454+
start=start,
455+
end=end,
456+
execution_time=execution_time,
457+
deployability_index=deployability_index,
458+
batch_index=batch_idx,
491459
)
492460

461+
evaluation_duration_ms = now_timestamp() - execution_start_ts
462+
finally:
463+
num_audits = len(audit_results)
464+
num_audits_failed = sum(1 for result in audit_results if result.count)
465+
466+
execution_stats = QueryExecutionTracker.get_execution_stats(
467+
f"{snapshot.snapshot_id}_{batch_idx}"
468+
)
469+
rows_processed = (
470+
execution_stats["total_rows_processed"] if execution_stats else None
471+
)
472+
473+
self.console.update_snapshot_evaluation_progress(
474+
snapshot,
475+
batched_intervals[snapshot][batch_idx],
476+
batch_idx,
477+
evaluation_duration_ms,
478+
num_audits - num_audits_failed,
479+
num_audits_failed,
480+
rows_processed=rows_processed,
481+
)
482+
493483
try:
494484
with self.snapshot_evaluator.concurrent_context():
495485
errors, skipped_intervals = concurrent_apply_to_dag(
@@ -529,9 +519,6 @@ def evaluate_node(node: SchedulingUnit) -> None:
529519

530520
self.state_sync.recycle()
531521

532-
# Clean up any remaining seed execution stats
533-
SeedExecutionTracker.clear_all_seed_stats()
534-
535522
def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
536523
"""Builds a DAG of snapshot intervals to be evaluated.
537524

0 commit comments

Comments
 (0)