44from contextlib import contextmanager
55from threading import local , Lock
66from dataclasses import dataclass , field
7+ from sqlmesh .core .snapshot import SnapshotIdBatch
78
89
910@dataclass
1011class QueryExecutionStats :
11- snapshot_batch_id : str
12+ snapshot_id_batch : SnapshotIdBatch
1213 total_rows_processed : t .Optional [int ] = None
1314 total_bytes_processed : t .Optional [int ] = None
1415
@@ -21,15 +22,15 @@ class QueryExecutionContext:
2122 It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation.
2223
2324 Attributes:
24- snapshot_batch_id : Identifier linking this context to a specific snapshot evaluation
25+ snapshot_id_batch : Identifier linking this context to a specific snapshot evaluation
2526 stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation
2627 """
2728
28- snapshot_batch_id : str
29+ snapshot_id_batch : SnapshotIdBatch
2930 stats : QueryExecutionStats = field (init = False )
3031
3132 def __post_init__ (self ) -> None :
32- self .stats = QueryExecutionStats (snapshot_batch_id = self .snapshot_batch_id )
33+ self .stats = QueryExecutionStats (snapshot_id_batch = self .snapshot_id_batch )
3334
3435 def add_execution (
3536 self , sql : str , row_count : t .Optional [int ], bytes_processed : t .Optional [int ]
@@ -56,10 +57,12 @@ class QueryExecutionTracker:
5657 """Thread-local context manager for snapshot execution statistics, such as rows processed."""
5758
5859 _thread_local = local ()
59- _contexts : t .Dict [str , QueryExecutionContext ] = {}
60+ _contexts : t .Dict [SnapshotIdBatch , QueryExecutionContext ] = {}
6061 _contexts_lock = Lock ()
6162
62- def get_execution_context (self , snapshot_id_batch : str ) -> t .Optional [QueryExecutionContext ]:
63+ def get_execution_context (
64+ self , snapshot_id_batch : SnapshotIdBatch
65+ ) -> t .Optional [QueryExecutionContext ]:
6366 with self ._contexts_lock :
6467 return self ._contexts .get (snapshot_id_batch )
6568
@@ -69,10 +72,10 @@ def is_tracking(cls) -> bool:
6972
7073 @contextmanager
7174 def track_execution (
72- self , snapshot_id_batch : str
75+ self , snapshot_id_batch : SnapshotIdBatch
7376 ) -> t .Iterator [t .Optional [QueryExecutionContext ]]:
7477 """Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
75- context = QueryExecutionContext (snapshot_batch_id = snapshot_id_batch )
78+ context = QueryExecutionContext (snapshot_id_batch = snapshot_id_batch )
7679 self ._thread_local .context = context
7780 with self ._contexts_lock :
7881 self ._contexts [snapshot_id_batch ] = context
@@ -90,7 +93,9 @@ def record_execution(
9093 if context is not None :
9194 context .add_execution (sql , row_count , bytes_processed )
9295
93- def get_execution_stats (self , snapshot_id_batch : str ) -> t .Optional [QueryExecutionStats ]:
96+ def get_execution_stats (
97+ self , snapshot_id_batch : SnapshotIdBatch
98+ ) -> t .Optional [QueryExecutionStats ]:
9499 with self ._contexts_lock :
95100 context = self ._contexts .get (snapshot_id_batch )
96101 self ._contexts .pop (snapshot_id_batch , None )
0 commit comments