Skip to content

Commit a78e5d1

Browse files
committed
Use threading.local() instead of locks
1 parent d93d3a6 commit a78e5d1

File tree

1 file changed

+20
-43
lines changed

1 file changed

+20
-43
lines changed

sqlmesh/core/execution_tracker.py

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import typing as t
55
from contextlib import contextmanager
6-
from threading import get_ident, Lock
6+
from threading import local
77
from dataclasses import dataclass, field
88

99

@@ -50,14 +50,11 @@ class QueryExecutionTracker:
5050
rows processed.
5151
"""
5252

53-
_thread_contexts: t.Dict[int, QueryExecutionContext] = {}
54-
_contexts_lock = Lock()
53+
_thread_local = local()
5554

5655
@classmethod
5756
def get_execution_context(cls) -> t.Optional[QueryExecutionContext]:
58-
thread_id = get_ident()
59-
with cls._contexts_lock:
60-
return cls._thread_contexts.get(thread_id)
57+
return getattr(cls._thread_local, "context", None)
6158

6259
@classmethod
6360
def is_tracking(cls) -> bool:
@@ -70,23 +67,18 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7067
Context manager for tracking snapshot evaluation execution statistics.
7168
"""
7269
context = QueryExecutionContext(id=snapshot_name_batch)
73-
thread_id = get_ident()
74-
75-
with cls._contexts_lock:
76-
cls._thread_contexts[thread_id] = context
70+
cls._thread_local.context = context
7771
try:
7872
yield context
7973
finally:
80-
with cls._contexts_lock:
81-
cls._thread_contexts.pop(thread_id, None)
74+
if hasattr(cls._thread_local, "context"):
75+
delattr(cls._thread_local, "context")
8276

8377
@classmethod
8478
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
85-
thread_id = get_ident()
86-
with cls._contexts_lock:
87-
context = cls._thread_contexts.get(thread_id)
88-
if context is not None:
89-
context.add_execution(sql, row_count)
79+
context = cls.get_execution_context()
80+
if context is not None:
81+
context.add_execution(sql, row_count)
9082

9183
@classmethod
9284
def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
@@ -96,57 +88,42 @@ def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
9688

9789
class SeedExecutionTracker:
9890
_seed_contexts: t.Dict[str, QueryExecutionContext] = {}
99-
_active_threads: t.Set[int] = set()
100-
_thread_to_seed_id: t.Dict[int, str] = {}
101-
_seed_contexts_lock = Lock()
91+
_thread_local = local()
10292

10393
@classmethod
10494
@contextmanager
10595
def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]:
10696
"""
10797
Context manager for tracking seed creation execution statistics.
10898
"""
109-
11099
context = QueryExecutionContext(id=model_name)
111-
thread_id = get_ident()
112-
113-
with cls._seed_contexts_lock:
114-
cls._seed_contexts[model_name] = context
115-
cls._active_threads.add(thread_id)
116-
cls._thread_to_seed_id[thread_id] = model_name
100+
cls._seed_contexts[model_name] = context
101+
cls._thread_local.seed_id = model_name
117102

118103
try:
119104
yield context
120105
finally:
121-
with cls._seed_contexts_lock:
122-
cls._active_threads.discard(thread_id)
123-
cls._thread_to_seed_id.pop(thread_id, None)
106+
if hasattr(cls._thread_local, "seed_id"):
107+
delattr(cls._thread_local, "seed_id")
124108

125109
@classmethod
126110
def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]:
127-
with cls._seed_contexts_lock:
128-
context = cls._seed_contexts.pop(model_name, None)
129-
return context.get_execution_stats() if context else None
111+
context = cls._seed_contexts.pop(model_name, None)
112+
return context.get_execution_stats() if context else None
130113

131114
@classmethod
132115
def clear_all_seed_stats(cls) -> None:
133116
"""Clear all remaining seed stats. Used for cleanup after evaluation completes."""
134-
with cls._seed_contexts_lock:
135-
cls._seed_contexts.clear()
117+
cls._seed_contexts.clear()
136118

137119
@classmethod
138120
def is_tracking(cls) -> bool:
139-
thread_id = get_ident()
140-
with cls._seed_contexts_lock:
141-
return thread_id in cls._active_threads
121+
return hasattr(cls._thread_local, "seed_id")
142122

143123
@classmethod
144124
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
145-
thread_id = get_ident()
146-
with cls._seed_contexts_lock:
147-
seed_id = cls._thread_to_seed_id.get(thread_id)
148-
if not seed_id:
149-
return
125+
seed_id = getattr(cls._thread_local, "seed_id", None)
126+
if seed_id:
150127
context = cls._seed_contexts.get(seed_id)
151128
if context is not None:
152129
context.add_execution(sql, row_count)

0 commit comments

Comments
 (0)