|
9 | 9 | from collections import OrderedDict
|
10 | 10 | from copy import deepcopy
|
11 | 11 | from datetime import timedelta
|
12 |
| -from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional |
| 12 | +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional |
13 | 13 |
|
14 | 14 | from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
|
15 | 15 | from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
|
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
|
66 | 66 | _GLOBAL_STATE_KEY = "state"
|
67 | 67 | _PERPARTITION_STATE_KEY = "states"
|
68 | 68 | _IS_PARTITION_DUPLICATION_LOGGED = False
|
69 |
| - _KEY = 0 |
70 |
| - _VALUE = 1 |
| 69 | + _PARENT_STATE = 0 |
| 70 | + _GENERATION_SEQUENCE = 1 |
71 | 71 |
|
72 | 72 | def __init__(
|
73 | 73 | self,
|
@@ -99,19 +99,29 @@ def __init__(
|
99 | 99 | self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
|
100 | 100 |
|
101 | 101 | # Parent-state tracking: store each partition’s parent state in creation order
|
102 |
| - self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict() |
| 102 | + self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = ( |
| 103 | + OrderedDict() |
| 104 | + ) |
| 105 | + self._parent_state: Optional[StreamState] = None |
| 106 | + |
| 107 | + # Tracks partitions for which the last slice has been emitted |
| 108 | + self._partitions_done_generating_stream_slices: set[str] = set() |
| 109 | + # Used to track the index of partitions that are not closed yet |
| 110 | + self._processing_partitions_indexes: List[int] = list() |
| 111 | + self._generated_partitions_count: int = 0 |
| 112 | + # Dictionary to map partition keys to their index |
| 113 | + self._partition_key_to_index: dict[str, int] = {} |
103 | 114 |
|
104 |
| - self._finished_partitions: set[str] = set() |
105 | 115 | self._lock = threading.Lock()
|
106 |
| - self._timer = Timer() |
107 |
| - self._new_global_cursor: Optional[StreamState] = None |
108 | 116 | self._lookback_window: int = 0
|
109 |
| - self._parent_state: Optional[StreamState] = None |
| 117 | + self._new_global_cursor: Optional[StreamState] = None |
110 | 118 | self._number_of_partitions: int = 0
|
111 | 119 | self._use_global_cursor: bool = use_global_cursor
|
112 | 120 | self._partition_serializer = PerPartitionKeySerializer()
|
| 121 | + |
113 | 122 | # Track the last time a state message was emitted
|
114 | 123 | self._last_emission_time: float = 0.0
|
| 124 | + self._timer = Timer() |
115 | 125 |
|
116 | 126 | self._set_initial_state(stream_state)
|
117 | 127 |
|
@@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None:
|
157 | 167 | self._cursor_per_partition[partition_key].close_partition(partition=partition)
|
158 | 168 | cursor = self._cursor_per_partition[partition_key]
|
159 | 169 | if (
|
160 |
| - partition_key in self._finished_partitions |
| 170 | + partition_key in self._partitions_done_generating_stream_slices |
161 | 171 | and self._semaphore_per_partition[partition_key]._value == 0
|
162 | 172 | ):
|
163 | 173 | self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
|
164 | 174 |
|
| 175 | + # Clean up the partition if it is fully processed |
| 176 | + self._cleanup_if_done(partition_key) |
| 177 | + |
165 | 178 | self._check_and_update_parent_state()
|
166 | 179 |
|
167 | 180 | self._emit_state_message()
|
168 | 181 |
|
169 | 182 | def _check_and_update_parent_state(self) -> None:
|
170 |
| - """ |
171 |
| - Pop the leftmost partition state from _partition_parent_state_map only if |
172 |
| - *all partitions* up to (and including) that partition key in _semaphore_per_partition |
173 |
| - are fully finished (i.e. in _finished_partitions and semaphore._value == 0). |
174 |
| - Additionally, delete finished semaphores with a value of 0 to free up memory, |
175 |
| - as they are only needed to track errors and completion status. |
176 |
| - """ |
177 | 183 | last_closed_state = None
|
178 | 184 |
|
179 | 185 | while self._partition_parent_state_map:
|
180 |
| - # Look at the earliest partition key in creation order |
181 |
| - earliest_key = next(iter(self._partition_parent_state_map)) |
182 |
| - |
183 |
| - # Verify ALL partitions from the left up to earliest_key are finished |
184 |
| - all_left_finished = True |
185 |
| - for p_key, sem in list( |
186 |
| - self._semaphore_per_partition.items() |
187 |
| - ): # Use list to allow modification during iteration |
188 |
| - # If any earlier partition is still not finished, we must stop |
189 |
| - if p_key not in self._finished_partitions or sem._value != 0: |
190 |
| - all_left_finished = False |
191 |
| - break |
192 |
| - # Once we've reached earliest_key in the semaphore order, we can stop checking |
193 |
| - if p_key == earliest_key: |
194 |
| - break |
195 |
| - |
196 |
| - # If the partitions up to earliest_key are not all finished, break the while-loop |
197 |
| - if not all_left_finished: |
198 |
| - break |
| 186 | + earliest_key, (candidate_state, candidate_seq) = next( |
| 187 | + iter(self._partition_parent_state_map.items()) |
| 188 | + ) |
199 | 189 |
|
200 |
| - # Pop the leftmost entry from parent-state map |
201 |
| - _, closed_parent_state = self._partition_parent_state_map.popitem(last=False) |
202 |
| - last_closed_state = closed_parent_state |
| 190 | + # if any partition that started <= candidate_seq is still open, we must wait |
| 191 | + if ( |
| 192 | + self._processing_partitions_indexes |
| 193 | + and self._processing_partitions_indexes[0] <= candidate_seq |
| 194 | + ): |
| 195 | + break |
203 | 196 |
|
204 |
| - # Clean up finished semaphores with value 0 up to and including earliest_key |
205 |
| - for p_key in list(self._semaphore_per_partition.keys()): |
206 |
| - sem = self._semaphore_per_partition[p_key] |
207 |
| - if p_key in self._finished_partitions and sem._value == 0: |
208 |
| - del self._semaphore_per_partition[p_key] |
209 |
| - logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") |
210 |
| - if p_key == earliest_key: |
211 |
| - break |
| 197 | + # safe to pop |
| 198 | + self._partition_parent_state_map.popitem(last=False) |
| 199 | + last_closed_state = candidate_state |
212 | 200 |
|
213 |
| - # Update _parent_state if we popped at least one partition |
214 | 201 | if last_closed_state is not None:
|
215 | 202 | self._parent_state = last_closed_state
|
216 | 203 |
|
@@ -289,26 +276,32 @@ def _generate_slices_from_partition(
|
289 | 276 | if not self._IS_PARTITION_DUPLICATION_LOGGED:
|
290 | 277 | logger.warning(f"Partition duplication detected for stream {self._stream_name}")
|
291 | 278 | self._IS_PARTITION_DUPLICATION_LOGGED = True
|
| 279 | + return |
292 | 280 | else:
|
293 | 281 | self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
|
294 | 282 |
|
295 | 283 | with self._lock:
|
| 284 | + seq = self._generated_partitions_count |
| 285 | + self._generated_partitions_count += 1 |
| 286 | + self._processing_partitions_indexes.append(seq) |
| 287 | + self._partition_key_to_index[partition_key] = seq |
| 288 | + |
296 | 289 | if (
|
297 | 290 | len(self._partition_parent_state_map) == 0
|
298 | 291 | or self._partition_parent_state_map[
|
299 | 292 | next(reversed(self._partition_parent_state_map))
|
300 |
| - ] |
| 293 | + ][self._PARENT_STATE] |
301 | 294 | != parent_state
|
302 | 295 | ):
|
303 |
| - self._partition_parent_state_map[partition_key] = deepcopy(parent_state) |
| 296 | + self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq) |
304 | 297 |
|
305 | 298 | for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
|
306 | 299 | cursor.stream_slices(),
|
307 | 300 | lambda: None,
|
308 | 301 | ):
|
309 | 302 | self._semaphore_per_partition[partition_key].release()
|
310 | 303 | if is_last_slice:
|
311 |
| - self._finished_partitions.add(partition_key) |
| 304 | + self._partitions_done_generating_stream_slices.add(partition_key) |
312 | 305 | yield StreamSlice(
|
313 | 306 | partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
|
314 | 307 | )
|
@@ -338,14 +331,11 @@ def _ensure_partition_limit(self) -> None:
|
338 | 331 | while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
|
339 | 332 | # Try removing finished partitions first
|
340 | 333 | for partition_key in list(self._cursor_per_partition.keys()):
|
341 |
| - if partition_key in self._finished_partitions and ( |
342 |
| - partition_key not in self._semaphore_per_partition |
343 |
| - or self._semaphore_per_partition[partition_key]._value == 0 |
344 |
| - ): |
| 334 | + if partition_key not in self._partition_key_to_index: |
345 | 335 | oldest_partition = self._cursor_per_partition.pop(
|
346 | 336 | partition_key
|
347 | 337 | ) # Remove the oldest partition
|
348 |
| - logger.warning( |
| 338 | + logger.debug( |
349 | 339 | f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
|
350 | 340 | )
|
351 | 341 | break
|
@@ -474,6 +464,25 @@ def _update_global_cursor(self, value: Any) -> None:
|
474 | 464 | ):
|
475 | 465 | self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}
|
476 | 466 |
|
| 467 | + def _cleanup_if_done(self, partition_key: str) -> None: |
| 468 | + """ |
| 469 | + Free the per-partition bookkeeping that belonged to a completed partition: |
| 470 | + semaphore, flag inside `_partitions_done_generating_stream_slices`, and sequence index (the cursor itself is kept) |
| 471 | + """ |
| 472 | + if not ( |
| 473 | + partition_key in self._partitions_done_generating_stream_slices |
| 474 | + and self._semaphore_per_partition[partition_key]._value == 0 |
| 475 | + ): |
| 476 | + return |
| 477 | + |
| 478 | + self._semaphore_per_partition.pop(partition_key, None) |
| 479 | + self._partitions_done_generating_stream_slices.discard(partition_key) |
| 480 | + |
| 481 | + seq = self._partition_key_to_index.pop(partition_key) |
| 482 | + self._processing_partitions_indexes.remove(seq) |
| 483 | + |
| 484 | + logger.debug(f"Partition {partition_key} fully processed and cleaned up.") |
| 485 | + |
477 | 486 | def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
|
478 | 487 | return self._partition_serializer.to_partition_key(partition)
|
479 | 488 |
|
|
0 commit comments