Skip to content

Commit 4991e07

Browse files
authored
feat(concurrent perpartition cursor): Add parent state updates (#343)
1 parent c3efa4c commit 4991e07

File tree

2 files changed

+516
-17
lines changed

2 files changed

+516
-17
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def __init__(
9595
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
9696
self._cursor_per_partition: OrderedDict[str, ConcurrentCursor] = OrderedDict()
9797
self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
98+
99+
# Parent-state tracking: store each partition’s parent state in creation order
100+
self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
101+
98102
self._finished_partitions: set[str] = set()
99103
self._lock = threading.Lock()
100104
self._timer = Timer()
@@ -155,11 +159,62 @@ def close_partition(self, partition: Partition) -> None:
155159
and self._semaphore_per_partition[partition_key]._value == 0
156160
):
157161
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
158-
self._emit_state_message()
162+
163+
self._check_and_update_parent_state()
164+
165+
self._emit_state_message()
166+
167+
def _check_and_update_parent_state(self) -> None:
168+
"""
169+
Pop the leftmost partition state from _partition_parent_state_map only if
170+
*all partitions* up to (and including) that partition key in _semaphore_per_partition
171+
are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
172+
Additionally, delete finished semaphores with a value of 0 to free up memory,
173+
as they are only needed to track errors and completion status.
174+
"""
175+
last_closed_state = None
176+
177+
while self._partition_parent_state_map:
178+
# Look at the earliest partition key in creation order
179+
earliest_key = next(iter(self._partition_parent_state_map))
180+
181+
# Verify ALL partitions from the left up to earliest_key are finished
182+
all_left_finished = True
183+
for p_key, sem in list(
184+
self._semaphore_per_partition.items()
185+
): # Use list to allow modification during iteration
186+
# If any earlier partition is still not finished, we must stop
187+
if p_key not in self._finished_partitions or sem._value != 0:
188+
all_left_finished = False
189+
break
190+
# Once we've reached earliest_key in the semaphore order, we can stop checking
191+
if p_key == earliest_key:
192+
break
193+
194+
# If the partitions up to earliest_key are not all finished, break the while-loop
195+
if not all_left_finished:
196+
break
197+
198+
# Pop the leftmost entry from parent-state map
199+
_, closed_parent_state = self._partition_parent_state_map.popitem(last=False)
200+
last_closed_state = closed_parent_state
201+
202+
# Clean up finished semaphores with value 0 up to and including earliest_key
203+
for p_key in list(self._semaphore_per_partition.keys()):
204+
sem = self._semaphore_per_partition[p_key]
205+
if p_key in self._finished_partitions and sem._value == 0:
206+
del self._semaphore_per_partition[p_key]
207+
logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0")
208+
if p_key == earliest_key:
209+
break
210+
211+
# Update _parent_state if we popped at least one partition
212+
if last_closed_state is not None:
213+
self._parent_state = last_closed_state
159214

160215
def ensure_at_least_one_state_emitted(self) -> None:
161216
"""
162-
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
217+
The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
163218
called.
164219
"""
165220
if not any(
@@ -201,13 +256,19 @@ def stream_slices(self) -> Iterable[StreamSlice]:
201256

202257
slices = self._partition_router.stream_slices()
203258
self._timer.start()
204-
for partition in slices:
205-
yield from self._generate_slices_from_partition(partition)
259+
for partition, last, parent_state in iterate_with_last_flag_and_state(
260+
slices, self._partition_router.get_stream_state
261+
):
262+
yield from self._generate_slices_from_partition(partition, parent_state)
206263

207-
def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
264+
def _generate_slices_from_partition(
265+
self, partition: StreamSlice, parent_state: Mapping[str, Any]
266+
) -> Iterable[StreamSlice]:
208267
# Ensure the maximum number of partitions is not exceeded
209268
self._ensure_partition_limit()
210269

270+
partition_key = self._to_partition_key(partition.partition)
271+
211272
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
212273
if not cursor:
213274
cursor = self._create_cursor(
@@ -216,18 +277,26 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St
216277
)
217278
with self._lock:
218279
self._number_of_partitions += 1
219-
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
220-
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
221-
threading.Semaphore(0)
222-
)
280+
self._cursor_per_partition[partition_key] = cursor
281+
self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
282+
283+
with self._lock:
284+
if (
285+
len(self._partition_parent_state_map) == 0
286+
or self._partition_parent_state_map[
287+
next(reversed(self._partition_parent_state_map))
288+
]
289+
!= parent_state
290+
):
291+
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
223292

224293
for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
225294
cursor.stream_slices(),
226295
lambda: None,
227296
):
228-
self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
297+
self._semaphore_per_partition[partition_key].release()
229298
if is_last_slice:
230-
self._finished_partitions.add(self._to_partition_key(partition.partition))
299+
self._finished_partitions.add(partition_key)
231300
yield StreamSlice(
232301
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
233302
)
@@ -257,9 +326,9 @@ def _ensure_partition_limit(self) -> None:
257326
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
258327
# Try removing finished partitions first
259328
for partition_key in list(self._cursor_per_partition.keys()):
260-
if (
261-
partition_key in self._finished_partitions
262-
and self._semaphore_per_partition[partition_key]._value == 0
329+
if partition_key in self._finished_partitions and (
330+
partition_key not in self._semaphore_per_partition
331+
or self._semaphore_per_partition[partition_key]._value == 0
263332
):
264333
oldest_partition = self._cursor_per_partition.pop(
265334
partition_key
@@ -338,9 +407,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
338407
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
339408
self._create_cursor(state["cursor"])
340409
)
341-
self._semaphore_per_partition[self._to_partition_key(state["partition"])] = (
342-
threading.Semaphore(0)
343-
)
344410

345411
# set default state for missing partitions if it is per partition with fallback to global
346412
if self._GLOBAL_STATE_KEY in stream_state:

0 commit comments

Comments
 (0)