@@ -27,7 +27,6 @@ class QueryExecutionContext:
2727 queries_executed : t .List [t .Tuple [str , t .Optional [int ], float ]] = field (default_factory = list )
2828
2929 def add_execution (self , sql : str , row_count : t .Optional [int ]) -> None :
30- """Record a single query execution."""
3130 if row_count is not None and row_count >= 0 :
3231 self .total_rows_processed += row_count
3332 self .query_count += 1
@@ -46,28 +45,36 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]:
4645
4746class QueryExecutionTracker :
4847 """
49- Thread-local context manager for snapshot evaluation execution statistics, such as
48+ Thread-local context manager for snapshot execution statistics, such as
5049 rows processed.
5150 """
5251
5352 _thread_local = local ()
53+ _contexts : t .Dict [str , QueryExecutionContext ] = {}
5454
5555 @classmethod
56- def get_execution_context (cls ) -> t .Optional [QueryExecutionContext ]:
57- return getattr ( cls ._thread_local , "context" , None )
56+ def get_execution_context (cls , snapshot_id_batch : str ) -> t .Optional [QueryExecutionContext ]:
57+ return cls ._contexts . get ( snapshot_id_batch )
5858
5959 @classmethod
6060 def is_tracking (cls ) -> bool :
61- return cls .get_execution_context ( ) is not None
61+ return getattr ( cls ._thread_local , "context" , None ) is not None
6262
6363 @classmethod
6464 @contextmanager
65- def track_execution (cls , snapshot_name_batch : str ) -> t .Iterator [QueryExecutionContext ]:
65+ def track_execution (
66+ cls , snapshot_id_batch : str , condition : bool = True
67+ ) -> t .Iterator [t .Optional [QueryExecutionContext ]]:
6668 """
67- Context manager for tracking snapshot evaluation execution statistics.
69+ Context manager for tracking snapshot execution statistics.
6870 """
69- context = QueryExecutionContext (id = snapshot_name_batch )
71+ if not condition :
72+ yield None
73+ return
74+
75+ context = QueryExecutionContext (id = snapshot_id_batch )
7076 cls ._thread_local .context = context
77+ cls ._contexts [snapshot_id_batch ] = context
7178 try :
7279 yield context
7380 finally :
@@ -76,67 +83,12 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7683
7784 @classmethod
7885 def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
79- context = cls .get_execution_context ( )
86+ context = getattr ( cls ._thread_local , "context" , None )
8087 if context is not None :
8188 context .add_execution (sql , row_count )
8289
8390 @classmethod
84- def get_execution_stats (cls ) -> t .Optional [t .Dict [str , t .Any ]]:
85- context = cls .get_execution_context ()
91+ def get_execution_stats (cls , snapshot_id_batch : str ) -> t .Optional [t .Dict [str , t .Any ]]:
92+ context = cls .get_execution_context (snapshot_id_batch )
93+ cls ._contexts .pop (snapshot_id_batch , None )
8694 return context .get_execution_stats () if context else None
87-
88-
89- class SeedExecutionTracker :
90- _seed_contexts : t .Dict [str , QueryExecutionContext ] = {}
91- _thread_local = local ()
92-
93- @classmethod
94- @contextmanager
95- def track_execution (cls , model_name : str ) -> t .Iterator [QueryExecutionContext ]:
96- """
97- Context manager for tracking seed creation execution statistics.
98- """
99- context = QueryExecutionContext (id = model_name )
100- cls ._seed_contexts [model_name ] = context
101- cls ._thread_local .seed_id = model_name
102-
103- try :
104- yield context
105- finally :
106- if hasattr (cls ._thread_local , "seed_id" ):
107- delattr (cls ._thread_local , "seed_id" )
108-
109- @classmethod
110- def get_and_clear_seed_stats (cls , model_name : str ) -> t .Optional [t .Dict [str , t .Any ]]:
111- context = cls ._seed_contexts .pop (model_name , None )
112- return context .get_execution_stats () if context else None
113-
114- @classmethod
115- def clear_all_seed_stats (cls ) -> None :
116- """Clear all remaining seed stats. Used for cleanup after evaluation completes."""
117- cls ._seed_contexts .clear ()
118-
119- @classmethod
120- def is_tracking (cls ) -> bool :
121- return hasattr (cls ._thread_local , "seed_id" )
122-
123- @classmethod
124- def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
125- seed_id = getattr (cls ._thread_local , "seed_id" , None )
126- if seed_id :
127- context = cls ._seed_contexts .get (seed_id )
128- if context is not None :
129- context .add_execution (sql , row_count )
130-
131-
132- def record_execution (sql : str , row_count : t .Optional [int ]) -> None :
133- """
134- Record execution statistics for a single SQL statement.
135-
136- Automatically infers which tracker is active based on the current thread.
137- """
138- if SeedExecutionTracker .is_tracking ():
139- SeedExecutionTracker .record_execution (sql , row_count )
140- return
141- if QueryExecutionTracker .is_tracking ():
142- QueryExecutionTracker .record_execution (sql , row_count )
0 commit comments