Commit 795a896

fix(concurrent-perpartition-cursor): Fix memory issues (#568)
1 parent 4be29b6 commit 795a896

File tree

2 files changed: +279 −62 lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 64 additions & 55 deletions
@@ -9,7 +9,7 @@
 from collections import OrderedDict
 from copy import deepcopy
 from datetime import timedelta
-from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional
 
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
     _GLOBAL_STATE_KEY = "state"
     _PERPARTITION_STATE_KEY = "states"
     _IS_PARTITION_DUPLICATION_LOGGED = False
-    _KEY = 0
-    _VALUE = 1
+    _PARENT_STATE = 0
+    _GENERATION_SEQUENCE = 1
 
     def __init__(
         self,
@@ -99,19 +99,29 @@ def __init__(
         self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
 
         # Parent-state tracking: store each partition’s parent state in creation order
-        self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
+        self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = (
+            OrderedDict()
+        )
+        self._parent_state: Optional[StreamState] = None
+
+        # Tracks when the last slice for partition is emitted
+        self._partitions_done_generating_stream_slices: set[str] = set()
+        # Used to track the index of partitions that are not closed yet
+        self._processing_partitions_indexes: List[int] = list()
+        self._generated_partitions_count: int = 0
+        # Dictionary to map partition keys to their index
+        self._partition_key_to_index: dict[str, int] = {}
 
-        self._finished_partitions: set[str] = set()
         self._lock = threading.Lock()
-        self._timer = Timer()
-        self._new_global_cursor: Optional[StreamState] = None
         self._lookback_window: int = 0
-        self._parent_state: Optional[StreamState] = None
+        self._new_global_cursor: Optional[StreamState] = None
         self._number_of_partitions: int = 0
         self._use_global_cursor: bool = use_global_cursor
         self._partition_serializer = PerPartitionKeySerializer()
+
         # Track the last time a state message was emitted
         self._last_emission_time: float = 0.0
+        self._timer = Timer()
 
         self._set_initial_state(stream_state)
 
@@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None:
                 self._cursor_per_partition[partition_key].close_partition(partition=partition)
                 cursor = self._cursor_per_partition[partition_key]
                 if (
-                    partition_key in self._finished_partitions
+                    partition_key in self._partitions_done_generating_stream_slices
                     and self._semaphore_per_partition[partition_key]._value == 0
                 ):
                     self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
 
+            # Clean up the partition if it is fully processed
+            self._cleanup_if_done(partition_key)
+
             self._check_and_update_parent_state()
 
             self._emit_state_message()
 
     def _check_and_update_parent_state(self) -> None:
-        """
-        Pop the leftmost partition state from _partition_parent_state_map only if
-        *all partitions* up to (and including) that partition key in _semaphore_per_partition
-        are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
-        Additionally, delete finished semaphores with a value of 0 to free up memory,
-        as they are only needed to track errors and completion status.
-        """
         last_closed_state = None
 
         while self._partition_parent_state_map:
-            # Look at the earliest partition key in creation order
-            earliest_key = next(iter(self._partition_parent_state_map))
-
-            # Verify ALL partitions from the left up to earliest_key are finished
-            all_left_finished = True
-            for p_key, sem in list(
-                self._semaphore_per_partition.items()
-            ):  # Use list to allow modification during iteration
-                # If any earlier partition is still not finished, we must stop
-                if p_key not in self._finished_partitions or sem._value != 0:
-                    all_left_finished = False
-                    break
-                # Once we've reached earliest_key in the semaphore order, we can stop checking
-                if p_key == earliest_key:
-                    break
-
-            # If the partitions up to earliest_key are not all finished, break the while-loop
-            if not all_left_finished:
-                break
+            earliest_key, (candidate_state, candidate_seq) = next(
+                iter(self._partition_parent_state_map.items())
+            )
 
-            # Pop the leftmost entry from parent-state map
-            _, closed_parent_state = self._partition_parent_state_map.popitem(last=False)
-            last_closed_state = closed_parent_state
+            # if any partition that started <= candidate_seq is still open, we must wait
+            if (
+                self._processing_partitions_indexes
+                and self._processing_partitions_indexes[0] <= candidate_seq
+            ):
+                break
 
-            # Clean up finished semaphores with value 0 up to and including earliest_key
-            for p_key in list(self._semaphore_per_partition.keys()):
-                sem = self._semaphore_per_partition[p_key]
-                if p_key in self._finished_partitions and sem._value == 0:
-                    del self._semaphore_per_partition[p_key]
-                    logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0")
-                if p_key == earliest_key:
-                    break
+            # safe to pop
+            self._partition_parent_state_map.popitem(last=False)
+            last_closed_state = candidate_state
 
-        # Update _parent_state if we popped at least one partition
         if last_closed_state is not None:
             self._parent_state = last_closed_state
 
@@ -289,26 +276,32 @@ def _generate_slices_from_partition(
             if not self._IS_PARTITION_DUPLICATION_LOGGED:
                 logger.warning(f"Partition duplication detected for stream {self._stream_name}")
                 self._IS_PARTITION_DUPLICATION_LOGGED = True
+            return
         else:
             self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
 
         with self._lock:
+            seq = self._generated_partitions_count
+            self._generated_partitions_count += 1
+            self._processing_partitions_indexes.append(seq)
+            self._partition_key_to_index[partition_key] = seq
+
             if (
                 len(self._partition_parent_state_map) == 0
                 or self._partition_parent_state_map[
                     next(reversed(self._partition_parent_state_map))
-                ]
+                ][self._PARENT_STATE]
                 != parent_state
             ):
-                self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
+                self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq)
 
         for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
             cursor.stream_slices(),
             lambda: None,
         ):
             self._semaphore_per_partition[partition_key].release()
             if is_last_slice:
-                self._finished_partitions.add(partition_key)
+                self._partitions_done_generating_stream_slices.add(partition_key)
             yield StreamSlice(
                 partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
             )
@@ -338,14 +331,11 @@ def _ensure_partition_limit(self) -> None:
             while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
                 # Try removing finished partitions first
                 for partition_key in list(self._cursor_per_partition.keys()):
-                    if partition_key in self._finished_partitions and (
-                        partition_key not in self._semaphore_per_partition
-                        or self._semaphore_per_partition[partition_key]._value == 0
-                    ):
+                    if partition_key not in self._partition_key_to_index:
                         oldest_partition = self._cursor_per_partition.pop(
                             partition_key
                         )  # Remove the oldest partition
-                        logger.warning(
+                        logger.debug(
                             f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
                         )
                         break
@@ -474,6 +464,25 @@ def _update_global_cursor(self, value: Any) -> None:
         ):
             self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}
 
+    def _cleanup_if_done(self, partition_key: str) -> None:
+        """
+        Free every in-memory structure that belonged to a completed partition:
+        cursor, semaphore, flag inside `_finished_partitions`
+        """
+        if not (
+            partition_key in self._partitions_done_generating_stream_slices
+            and self._semaphore_per_partition[partition_key]._value == 0
+        ):
+            return
+
+        self._semaphore_per_partition.pop(partition_key, None)
+        self._partitions_done_generating_stream_slices.discard(partition_key)
+
+        seq = self._partition_key_to_index.pop(partition_key)
+        self._processing_partitions_indexes.remove(seq)
+
+        logger.debug(f"Partition {partition_key} fully processed and cleaned up.")
+
     def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
         return self._partition_serializer.to_partition_key(partition)
 
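In short, the diff replaces the old "scan every semaphore from the left" check with monotonically increasing generation indexes: each generated partition gets a sequence number, fully processed partitions are freed immediately via _cleanup_if_done, and a recorded parent state is promoted only once no partition with an equal or earlier sequence number is still open. The sketch below is a minimal, self-contained illustration of that ordering rule; the names (ParentStateTracker, start_partition, close_partition) are simplified placeholders, not CDK API.

# Minimal sketch (not CDK code): sequence-number bookkeeping for parent-state promotion.
from collections import OrderedDict
from typing import Any, Mapping, Optional


class ParentStateTracker:
    def __init__(self) -> None:
        self._generated = 0  # next sequence number to hand out
        self._open: list[int] = []  # sequence numbers of partitions not yet closed, in start order
        self._key_to_seq: dict[str, int] = {}  # partition key -> sequence number
        self._pending: "OrderedDict[str, tuple[Mapping[str, Any], int]]" = OrderedDict()
        self.parent_state: Optional[Mapping[str, Any]] = None

    def start_partition(self, key: str, parent_state: Mapping[str, Any]) -> None:
        seq = self._generated
        self._generated += 1
        self._open.append(seq)
        self._key_to_seq[key] = seq
        # Record the parent state only when it differs from the most recently recorded one.
        if not self._pending or next(reversed(self._pending.values()))[0] != parent_state:
            self._pending[key] = (dict(parent_state), seq)

    def close_partition(self, key: str) -> None:
        seq = self._key_to_seq.pop(key)
        self._open.remove(seq)
        # Promote pending parent states in order, stopping at the first one that is
        # still blocked by an open partition that started at or before it.
        while self._pending:
            _, (candidate_state, candidate_seq) = next(iter(self._pending.items()))
            if self._open and self._open[0] <= candidate_seq:
                break
            self._pending.popitem(last=False)
            self.parent_state = candidate_state

For example, closing a later partition before an earlier one does not advance the parent state:

tracker = ParentStateTracker()
tracker.start_partition("A", {"cursor": 1})
tracker.start_partition("B", {"cursor": 2})
tracker.close_partition("B")
assert tracker.parent_state is None  # "A" (an earlier partition) is still open
tracker.close_partition("A")
assert tracker.parent_state == {"cursor": 2}  # both pending states promoted in order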
0 commit comments
