-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathconcurrency.py
143 lines (122 loc) · 4.35 KB
/
concurrency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import contextlib
import logging
import threading
from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    NamedTuple,
    Protocol,
    Sequence,
)

if TYPE_CHECKING:
    from types import TracebackType

    from cassandra.cluster import ResponseFuture, Session
    from cassandra.query import PreparedStatement, SimpleStatement
# Module-level logger; ConcurrentQueries reports failed queries through it.
logger = logging.getLogger(__name__)
class _Callback(Protocol):
    """Structural type of a per-page result callback.

    An implementation takes the rows of one result page as its single,
    positional-only argument. Its return value is not used.
    """

    def __call__(self, rows: Sequence[Any], /) -> None: ...
class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]):
    """Context manager for concurrent queries with a cap on in-flight queries.

    At most ``max_concurrent_queries`` queries (default
    ``_MAX_CONCURRENT_QUERIES``, i.e. 5) are executing at any moment.
    Further queries are queued and started in submission order as earlier
    ones finish. Submitting never blocks, so :meth:`execute` may be called
    from inside a result callback. The driver may run such callbacks on its
    I/O event thread, where blocking would stall all query processing.

    Leaving the ``with`` block waits until every submitted query has
    finished (all result pages delivered) or an error has been recorded.
    It then re-raises the first recorded error.
    """

    _MAX_CONCURRENT_QUERIES = 5

    def __init__(
        self,
        session: Session,
        max_concurrent_queries: int | None = None,
    ) -> None:
        """Create a new concurrent-query context for ``session``.

        Args:
            session: Session used to run the queries.
            max_concurrent_queries: Maximum number of queries executing at
                once. Defaults to ``_MAX_CONCURRENT_QUERIES``.

        Raises:
            ValueError: If ``max_concurrent_queries`` is less than 1.
        """
        if max_concurrent_queries is None:
            max_concurrent_queries = self._MAX_CONCURRENT_QUERIES
        if max_concurrent_queries < 1:
            msg = "max_concurrent_queries must be at least 1"
            raise ValueError(msg)
        self._session = session
        self._max_concurrent_queries = max_concurrent_queries
        # Guards every field below and wakes ``__exit__`` on progress.
        self._completion = threading.Condition()
        # Submitted but unfinished queries (always in_flight + len(queued)).
        self._pending = 0
        # Queries currently handed to the driver.
        self._in_flight = 0
        # Queries waiting for a slot: (query, parameters, callback, timeout).
        self._queued: deque[tuple[Any, ...]] = deque()
        self._error: BaseException | None = None

    def _start(
        self,
        query: PreparedStatement | SimpleStatement,
        parameters: tuple[Any, ...] | None,
        callback: _Callback | None,
        timeout: float | None,
    ) -> None:
        """Send one query whose concurrency slot has already been reserved."""
        execute_kwargs: dict[str, Any] = {}
        if timeout is not None:
            execute_kwargs["timeout"] = timeout
        future: ResponseFuture = self._session.execute_async(
            query,
            parameters,
            **execute_kwargs,
        )
        future.add_callbacks(
            self._handle_result,
            self._handle_error,
            callback_kwargs={
                "future": future,
                "callback": callback,
            },
            errback_kwargs={
                "future": future,
            },
        )

    def _finish(self, error: BaseException | None = None) -> None:
        """Mark one in-flight query as finished and start the next queued one.

        Args:
            error: The failure that ended the query, if any. Only the first
                recorded error is kept. Once an error is recorded, queued
                queries are discarded instead of being started.
        """
        while True:
            with self._completion:
                if error is not None and self._error is None:
                    self._error = error
                self._pending -= 1
                self._in_flight -= 1
                if self._error is not None:
                    # Queued queries will never run now; stop counting them.
                    self._pending -= len(self._queued)
                    self._queued.clear()
                if self._error is not None or self._pending == 0:
                    self._completion.notify_all()
                if not self._queued:
                    return
                next_query = self._queued.popleft()
                self._in_flight += 1
            # Start outside the lock: the driver may invoke our handlers
            # synchronously if the result is already available.
            try:
                self._start(*next_query)
            except Exception as e:
                logger.error(
                    "Error executing query: %s",
                    next_query[0],
                    exc_info=e,
                )
                # Finish the query that failed to start (loop again).
                error = e
                continue
            return

    def _handle_result(
        self,
        result: Sequence[NamedTuple],
        future: ResponseFuture,
        callback: Callable[[Sequence[NamedTuple]], Any] | None,
    ) -> None:
        """Driver callback: deliver one page, then fetch the next or finish.

        If ``callback`` or the page fetch raises, the exception is recorded as
        the query's error. Without this the query would never be counted as
        finished, and ``__exit__`` would wait forever.
        """
        try:
            if callback is not None:
                callback(result)
            if future.has_more_pages:
                # The registered handlers are invoked again for the next page.
                future.start_fetching_next_page()
                return
        except Exception as e:
            logger.error(
                "Error handling result of query: %s",
                future.query,
                exc_info=e,
            )
            self._finish(e)
            return
        self._finish()

    def _handle_error(self, error: BaseException, future: ResponseFuture) -> None:
        """Driver errback: log the failure and record it as the query's error."""
        logger.error(
            "Error executing query: %s",
            future.query,
            exc_info=error,
        )
        self._finish(error)

    def execute(
        self,
        query: PreparedStatement | SimpleStatement,
        parameters: tuple[Any, ...] | None = None,
        callback: _Callback | None = None,
        timeout: float | None = None,
    ) -> None:
        """Submit a query to run as soon as a concurrency slot is free.

        This never blocks. If the limit is reached, the query is queued.
        Once an error has been recorded, further submissions are ignored,
        and the error is raised when the ``with`` block exits.

        Args:
            query: The query to execute.
            parameters: Parameter tuple for the query. Defaults to `None`.
            callback: Callback applied to each page of results. Defaults to
                `None`.
            timeout: Timeout to use (if not the session default).

        Raises:
            Exception: Whatever ``session.execute_async`` raises synchronously
                when the query is started immediately. The error is also
                recorded, so the ``with`` block re-raises it.
        """
        with self._completion:
            if self._error is not None:
                return
            self._pending += 1
            if self._in_flight >= self._max_concurrent_queries:
                self._queued.append((query, parameters, callback, timeout))
                return
            self._in_flight += 1
        # Start outside the lock: the driver may invoke our handlers
        # synchronously if the result is already available.
        try:
            self._start(query, parameters, callback, timeout)
        except Exception as e:
            self._finish(e)
            raise

    def __exit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_inst: BaseException | None,
        _exc_traceback: TracebackType | None,
    ) -> Literal[False]:
        """Wait for all submitted queries or the first error, then re-raise it.

        Returns:
            ``False``, so an exception from the ``with`` body is not
            swallowed. A recorded query error is raised in its place.
        """
        with self._completion:
            self._completion.wait_for(
                lambda: self._error is not None or self._pending == 0,
            )
            if self._error is not None:
                raise self._error
        # Don't swallow the exception.
        return False