33import time
44import typing as t
55from contextlib import contextmanager
6- from threading import get_ident , Lock
6+ from threading import local
77from dataclasses import dataclass , field
88
99
@@ -50,14 +50,11 @@ class QueryExecutionTracker:
5050 rows processed.
5151 """
5252
53- _thread_contexts : t .Dict [int , QueryExecutionContext ] = {}
54- _contexts_lock = Lock ()
53+ _thread_local = local ()
5554
5655 @classmethod
5756 def get_execution_context (cls ) -> t .Optional [QueryExecutionContext ]:
58- thread_id = get_ident ()
59- with cls ._contexts_lock :
60- return cls ._thread_contexts .get (thread_id )
57+ return getattr (cls ._thread_local , "context" , None )
6158
6259 @classmethod
6360 def is_tracking (cls ) -> bool :
@@ -70,23 +67,18 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7067 Context manager for tracking snapshot evaluation execution statistics.
7168 """
7269 context = QueryExecutionContext (id = snapshot_name_batch )
73- thread_id = get_ident ()
74-
75- with cls ._contexts_lock :
76- cls ._thread_contexts [thread_id ] = context
70+ cls ._thread_local .context = context
7771 try :
7872 yield context
7973 finally :
80- with cls ._contexts_lock :
81- cls ._thread_contexts . pop ( thread_id , None )
74+ if hasattr ( cls ._thread_local , "context" ) :
75+ delattr ( cls ._thread_local , "context" )
8276
8377 @classmethod
8478 def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
85- thread_id = get_ident ()
86- with cls ._contexts_lock :
87- context = cls ._thread_contexts .get (thread_id )
88- if context is not None :
89- context .add_execution (sql , row_count )
79+ context = cls .get_execution_context ()
80+ if context is not None :
81+ context .add_execution (sql , row_count )
9082
9183 @classmethod
9284 def get_execution_stats (cls ) -> t .Optional [t .Dict [str , t .Any ]]:
@@ -96,57 +88,42 @@ def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
9688
9789class SeedExecutionTracker :
9890 _seed_contexts : t .Dict [str , QueryExecutionContext ] = {}
99- _active_threads : t .Set [int ] = set ()
100- _thread_to_seed_id : t .Dict [int , str ] = {}
101- _seed_contexts_lock = Lock ()
91+ _thread_local = local ()
10292
10393 @classmethod
10494 @contextmanager
10595 def track_execution (cls , model_name : str ) -> t .Iterator [QueryExecutionContext ]:
10696 """
10797 Context manager for tracking seed creation execution statistics.
10898 """
109-
11099 context = QueryExecutionContext (id = model_name )
111- thread_id = get_ident ()
112-
113- with cls ._seed_contexts_lock :
114- cls ._seed_contexts [model_name ] = context
115- cls ._active_threads .add (thread_id )
116- cls ._thread_to_seed_id [thread_id ] = model_name
100+ cls ._seed_contexts [model_name ] = context
101+ cls ._thread_local .seed_id = model_name
117102
118103 try :
119104 yield context
120105 finally :
121- with cls ._seed_contexts_lock :
122- cls ._active_threads .discard (thread_id )
123- cls ._thread_to_seed_id .pop (thread_id , None )
106+ if hasattr (cls ._thread_local , "seed_id" ):
107+ delattr (cls ._thread_local , "seed_id" )
124108
125109 @classmethod
126110 def get_and_clear_seed_stats (cls , model_name : str ) -> t .Optional [t .Dict [str , t .Any ]]:
127- with cls ._seed_contexts_lock :
128- context = cls ._seed_contexts .pop (model_name , None )
129- return context .get_execution_stats () if context else None
111+ context = cls ._seed_contexts .pop (model_name , None )
112+ return context .get_execution_stats () if context else None
130113
131114 @classmethod
132115 def clear_all_seed_stats (cls ) -> None :
133116 """Clear all remaining seed stats. Used for cleanup after evaluation completes."""
134- with cls ._seed_contexts_lock :
135- cls ._seed_contexts .clear ()
117+ cls ._seed_contexts .clear ()
136118
137119 @classmethod
138120 def is_tracking (cls ) -> bool :
139- thread_id = get_ident ()
140- with cls ._seed_contexts_lock :
141- return thread_id in cls ._active_threads
121+ return hasattr (cls ._thread_local , "seed_id" )
142122
143123 @classmethod
144124 def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
145- thread_id = get_ident ()
146- with cls ._seed_contexts_lock :
147- seed_id = cls ._thread_to_seed_id .get (thread_id )
148- if not seed_id :
149- return
125+ seed_id = getattr (cls ._thread_local , "seed_id" , None )
126+ if seed_id :
150127 context = cls ._seed_contexts .get (seed_id )
151128 if context is not None :
152129 context .add_execution (sql , row_count )
0 commit comments