22
33import typing as t
44from contextlib import contextmanager
5- from threading import local , Lock
5+ from threading import local
66from dataclasses import dataclass , field
77from sqlmesh .core .snapshot import SnapshotIdBatch
88
@@ -56,19 +56,17 @@ def get_execution_stats(self) -> QueryExecutionStats:
5656class QueryExecutionTracker :
5757 """Thread-local context manager for snapshot execution statistics, such as rows processed."""
5858
59- _thread_local = local ()
60- _contexts : t . Dict [ SnapshotIdBatch , QueryExecutionContext ] = {}
61- _contexts_lock = Lock ()
59+ def __init__ ( self ) -> None :
60+ self . _thread_local = local ()
61+ self . _contexts : t . Dict [ SnapshotIdBatch , QueryExecutionContext ] = {}
6262
6363 def get_execution_context (
6464 self , snapshot_id_batch : SnapshotIdBatch
6565 ) -> t .Optional [QueryExecutionContext ]:
66- with self ._contexts_lock :
67- return self ._contexts .get (snapshot_id_batch )
66+ return self ._contexts .get (snapshot_id_batch )
6867
69- @classmethod
70- def is_tracking (cls ) -> bool :
71- return getattr (cls ._thread_local , "context" , None ) is not None
68+ def is_tracking (self ) -> bool :
69+ return getattr (self ._thread_local , "context" , None ) is not None
7270
7371 @contextmanager
7472 def track_execution (
@@ -77,26 +75,23 @@ def track_execution(
7775 """Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
7876 context = QueryExecutionContext (snapshot_id_batch = snapshot_id_batch )
7977 self ._thread_local .context = context
80- with self ._contexts_lock :
81- self ._contexts [snapshot_id_batch ] = context
78+ self ._contexts [snapshot_id_batch ] = context
8279
8380 try :
8481 yield context
8582 finally :
8683 self ._thread_local .context = None
8784
88- @classmethod
8985 def record_execution (
90- cls , sql : str , row_count : t .Optional [int ], bytes_processed : t .Optional [int ]
86+ self , sql : str , row_count : t .Optional [int ], bytes_processed : t .Optional [int ]
9187 ) -> None :
92- context = getattr (cls ._thread_local , "context" , None )
88+ context = getattr (self ._thread_local , "context" , None )
9389 if context is not None :
9490 context .add_execution (sql , row_count , bytes_processed )
9591
9692 def get_execution_stats (
9793 self , snapshot_id_batch : SnapshotIdBatch
9894 ) -> t .Optional [QueryExecutionStats ]:
99- with self ._contexts_lock :
100- context = self ._contexts .get (snapshot_id_batch )
101- self ._contexts .pop (snapshot_id_batch , None )
95+ context = self ._contexts .get (snapshot_id_batch )
96+ self ._contexts .pop (snapshot_id_batch , None )
10297 return context .get_execution_stats () if context else None
0 commit comments