diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
index 0aae6be8a..a2194c75a 100644
--- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
@@ -9,7 +9,7 @@
 from collections import OrderedDict
 from copy import deepcopy
 from datetime import timedelta
-from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional
 
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
     _GLOBAL_STATE_KEY = "state"
     _PERPARTITION_STATE_KEY = "states"
     _IS_PARTITION_DUPLICATION_LOGGED = False
-    _KEY = 0
-    _VALUE = 1
+    _PARENT_STATE = 0
+    _GENERATION_SEQUENCE = 1
 
     def __init__(
         self,
@@ -99,19 +99,29 @@ def __init__(
         self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
 
         # Parent-state tracking: store each partition’s parent state in creation order
-        self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
+        self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = (
+            OrderedDict()
+        )
+        self._parent_state: Optional[StreamState] = None
+
+        # Tracks partitions whose last stream slice has been emitted
+        self._partitions_done_generating_stream_slices: set[str] = set()
+        # Tracks the indexes of partitions that are not yet closed
+        self._processing_partitions_indexes: List[int] = list()
+        self._generated_partitions_count: int = 0
+        # Dictionary to map partition keys to their index
+        self._partition_key_to_index: dict[str, int] = {}
 
-        self._finished_partitions: set[str] = set()
         self._lock = threading.Lock()
-        self._timer = Timer()
-        self._new_global_cursor: Optional[StreamState] = None
         self._lookback_window: int = 0
-        self._parent_state: Optional[StreamState] = None
+        self._new_global_cursor: Optional[StreamState] = None
         self._number_of_partitions: int = 0
         self._use_global_cursor: bool = use_global_cursor
         self._partition_serializer = PerPartitionKeySerializer()
+        # Track the last time a state message was emitted
         self._last_emission_time: float = 0.0
+        self._timer = Timer()
 
         self._set_initial_state(stream_state)
@@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None:
             self._cursor_per_partition[partition_key].close_partition(partition=partition)
             cursor = self._cursor_per_partition[partition_key]
             if (
-                partition_key in self._finished_partitions
+                partition_key in self._partitions_done_generating_stream_slices
                 and self._semaphore_per_partition[partition_key]._value == 0
             ):
                 self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
 
+            # Clean up the partition if it is fully processed
+            self._cleanup_if_done(partition_key)
+
             self._check_and_update_parent_state()
 
             self._emit_state_message()
 
     def _check_and_update_parent_state(self) -> None:
-        """
-        Pop the leftmost partition state from _partition_parent_state_map only if
-        *all partitions* up to (and including) that partition key in _semaphore_per_partition
-        are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
-        Additionally, delete finished semaphores with a value of 0 to free up memory,
-        as they are only needed to track errors and completion status.
- """ last_closed_state = None while self._partition_parent_state_map: - # Look at the earliest partition key in creation order - earliest_key = next(iter(self._partition_parent_state_map)) - - # Verify ALL partitions from the left up to earliest_key are finished - all_left_finished = True - for p_key, sem in list( - self._semaphore_per_partition.items() - ): # Use list to allow modification during iteration - # If any earlier partition is still not finished, we must stop - if p_key not in self._finished_partitions or sem._value != 0: - all_left_finished = False - break - # Once we've reached earliest_key in the semaphore order, we can stop checking - if p_key == earliest_key: - break - - # If the partitions up to earliest_key are not all finished, break the while-loop - if not all_left_finished: - break + earliest_key, (candidate_state, candidate_seq) = next( + iter(self._partition_parent_state_map.items()) + ) - # Pop the leftmost entry from parent-state map - _, closed_parent_state = self._partition_parent_state_map.popitem(last=False) - last_closed_state = closed_parent_state + # if any partition that started <= candidate_seq is still open, we must wait + if ( + self._processing_partitions_indexes + and self._processing_partitions_indexes[0] <= candidate_seq + ): + break - # Clean up finished semaphores with value 0 up to and including earliest_key - for p_key in list(self._semaphore_per_partition.keys()): - sem = self._semaphore_per_partition[p_key] - if p_key in self._finished_partitions and sem._value == 0: - del self._semaphore_per_partition[p_key] - logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") - if p_key == earliest_key: - break + # safe to pop + self._partition_parent_state_map.popitem(last=False) + last_closed_state = candidate_state - # Update _parent_state if we popped at least one partition if last_closed_state is not None: self._parent_state = last_closed_state @@ -289,18 +276,24 @@ def _generate_slices_from_partition( if not self._IS_PARTITION_DUPLICATION_LOGGED: logger.warning(f"Partition duplication detected for stream {self._stream_name}") self._IS_PARTITION_DUPLICATION_LOGGED = True + return else: self._semaphore_per_partition[partition_key] = threading.Semaphore(0) with self._lock: + seq = self._generated_partitions_count + self._generated_partitions_count += 1 + self._processing_partitions_indexes.append(seq) + self._partition_key_to_index[partition_key] = seq + if ( len(self._partition_parent_state_map) == 0 or self._partition_parent_state_map[ next(reversed(self._partition_parent_state_map)) - ] + ][self._PARENT_STATE] != parent_state ): - self._partition_parent_state_map[partition_key] = deepcopy(parent_state) + self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq) for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state( cursor.stream_slices(), @@ -308,7 +301,7 @@ def _generate_slices_from_partition( ): self._semaphore_per_partition[partition_key].release() if is_last_slice: - self._finished_partitions.add(partition_key) + self._partitions_done_generating_stream_slices.add(partition_key) yield StreamSlice( partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields ) @@ -338,14 +331,11 @@ def _ensure_partition_limit(self) -> None: while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: # Try removing finished partitions first for partition_key in list(self._cursor_per_partition.keys()): - if partition_key in self._finished_partitions and ( - 
-                    partition_key not in self._semaphore_per_partition
-                    or self._semaphore_per_partition[partition_key]._value == 0
-                ):
+                if partition_key not in self._partition_key_to_index:
                     oldest_partition = self._cursor_per_partition.pop(
                         partition_key
                     )  # Remove the oldest partition
-                    logger.warning(
+                    logger.debug(
                         f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
                     )
                     break
@@ -474,6 +464,25 @@ def _update_global_cursor(self, value: Any) -> None:
         ):
             self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}
 
+    def _cleanup_if_done(self, partition_key: str) -> None:
+        """
+        Free the in-memory bookkeeping for a fully processed partition: its semaphore, its
+        flag in `_partitions_done_generating_stream_slices`, and its index tracking entries.
+        """
+        if not (
+            partition_key in self._partitions_done_generating_stream_slices
+            and self._semaphore_per_partition[partition_key]._value == 0
+        ):
+            return
+
+        self._semaphore_per_partition.pop(partition_key, None)
+        self._partitions_done_generating_stream_slices.discard(partition_key)
+
+        seq = self._partition_key_to_index.pop(partition_key)
+        self._processing_partitions_indexes.remove(seq)
+
+        logger.debug(f"Partition {partition_key} fully processed and cleaned up.")
+
     def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
         return self._partition_serializer.to_partition_key(partition)
diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
index b54fc4779..ae6ec0713 100644
--- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
+++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
@@ -3436,7 +3436,11 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
     }
     assert mock_cursor_1.stream_slices.call_count == 1  # Called once for each partition
     assert mock_cursor_2.stream_slices.call_count == 1  # Called once for each partition
-    assert len(cursor._semaphore_per_partition) == 2
+
+    assert len(cursor._semaphore_per_partition) == 1
+    assert len(cursor._partitions_done_generating_stream_slices) == 1
+    assert len(cursor._processing_partitions_indexes) == 1
+    assert len(cursor._partition_key_to_index) == 1
 
 
 def test_given_unfinished_last_parent_partition_with_partial_parent_state_update():
@@ -3520,7 +3524,11 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
     }
     assert mock_cursor_1.stream_slices.call_count == 1  # Called once for each partition
     assert mock_cursor_2.stream_slices.call_count == 1  # Called once for each partition
+
+    assert len(cursor._semaphore_per_partition) == 1
+    assert len(cursor._partitions_done_generating_stream_slices) == 1
+    assert len(cursor._processing_partitions_indexes) == 1
+    assert len(cursor._partition_key_to_index) == 1
 
 
 def test_given_all_partitions_finished_when_close_partition_then_final_state_emitted():
@@ -3595,7 +3603,12 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
     assert final_state["lookback_window"] == 1
     assert cursor._message_repository.emit_message.call_count == 2
     assert mock_cursor.stream_slices.call_count == 2  # Called once for each partition
-    assert len(cursor._semaphore_per_partition) == 1
+
+    # Check that all internal variables are cleaned up
+    assert len(cursor._semaphore_per_partition) == 0
+    assert len(cursor._partitions_done_generating_stream_slices) == 0
+    assert len(cursor._processing_partitions_indexes) == 0
+    assert len(cursor._partition_key_to_index) == 0
 
 
 def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_global_cursor():
@@ -3714,18 +3727,20 @@ def test_semaphore_cleanup():
     # Verify initial state
     assert len(cursor._semaphore_per_partition) == 2
     assert len(cursor._partition_parent_state_map) == 2
-    assert cursor._partition_parent_state_map['{"id":"1"}'] == {"parent": {"state": "state1"}}
-    assert cursor._partition_parent_state_map['{"id":"2"}'] == {"parent": {"state": "state2"}}
+    assert len(cursor._processing_partitions_indexes) == 2
+    assert len(cursor._partition_key_to_index) == 2
+    assert cursor._partition_parent_state_map['{"id":"1"}'][0] == {"parent": {"state": "state1"}}
+    assert cursor._partition_parent_state_map['{"id":"2"}'][0] == {"parent": {"state": "state2"}}
 
     # Close partitions to acquire semaphores (value back to 0)
     for s in generated_slices:
         cursor.close_partition(DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), s))
 
     # Check state after closing partitions
-    assert len(cursor._finished_partitions) == 2
+    assert len(cursor._partitions_done_generating_stream_slices) == 0
     assert len(cursor._semaphore_per_partition) == 0
-    assert '{"id":"1"}' not in cursor._semaphore_per_partition
-    assert '{"id":"2"}' not in cursor._semaphore_per_partition
+    assert len(cursor._processing_partitions_indexes) == 0
+    assert len(cursor._partition_key_to_index) == 0
     assert len(cursor._partition_parent_state_map) == 0  # All parent states should be popped
     assert cursor._parent_state == {"parent": {"state": "state2"}}  # Last parent state
@@ -3773,3 +3788,196 @@ def test_given_global_state_when_read_then_state_is_not_per_partition() -> None:
         "use_global_cursor": True,  # ensures that it is running the Concurrent CDK version as this is not populated in the declarative implementation
     },  # this state does have per partition which would be under `states`
 )
+
+
+def _make_inner_cursor(ts: str) -> MagicMock:
+    """Return an inner cursor that yields exactly one slice and has a proper state."""
+    inner = MagicMock()
+    inner.stream_slices.side_effect = lambda: iter([{"dummy": "slice"}])
+    inner.state = {"updated_at": ts}
+    inner.close_partition.return_value = None
+    inner.observe.return_value = None
+    return inner
+
+
+def test_duplicate_partition_after_closing_partition_cursor_deleted():
+    inner_cursors = [
+        _make_inner_cursor("2024-01-01T00:00:00Z"),  # for first "1"
+        _make_inner_cursor("2024-01-02T00:00:00Z"),  # for "2"
+        _make_inner_cursor("2024-01-03T00:00:00Z"),  # for second "1"
+    ]
+    cursor_factory_mock = MagicMock()
+    cursor_factory_mock.create.side_effect = inner_cursors
+
+    converter = CustomFormatConcurrentStreamStateConverter(
+        datetime_format="%Y-%m-%dT%H:%M:%SZ",
+        input_datetime_formats=["%Y-%m-%dT%H:%M:%SZ"],
+        is_sequential_state=True,
+        cursor_granularity=timedelta(0),
+    )
+
+    cursor = ConcurrentPerPartitionCursor(
+        cursor_factory=cursor_factory_mock,
+        partition_router=MagicMock(),
+        stream_name="dup_stream",
+        stream_namespace=None,
+        stream_state={},
+        message_repository=MagicMock(),
+        connector_state_manager=MagicMock(),
+        connector_state_converter=converter,
+        cursor_field=CursorField(cursor_field_key="updated_at"),
+    )
+
+    cursor.DEFAULT_MAX_PARTITIONS_NUMBER = 1
+
+    # ── Partition sequence: 1 → 2 → 1 ──────────────────────────────────
+    partitions = [
+        StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}),
StreamSlice(partition={"id": "2"}, cursor_slice={}, extra_fields={}), + StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}), + ] + pr = cursor._partition_router + pr.stream_slices.return_value = iter(partitions) + pr.get_stream_state.return_value = {} + + # Iterate lazily so that the first "1" gets cleaned before + # the second "1" arrives. + slice_gen = cursor.stream_slices() + + first_1 = next(slice_gen) + cursor.close_partition( + DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1) + ) + + two = next(slice_gen) + cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two)) + + second_1 = next(slice_gen) + cursor.close_partition( + DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1) + ) + + assert cursor._IS_PARTITION_DUPLICATION_LOGGED is False # No duplicate detected + assert len(cursor._semaphore_per_partition) == 0 + assert len(cursor._processing_partitions_indexes) == 0 + assert len(cursor._partition_key_to_index) == 0 + assert len(cursor._partitions_done_generating_stream_slices) == 0 + + +def test_duplicate_partition_after_closing_partition_cursor_exists(): + inner_cursors = [ + _make_inner_cursor("2024-01-01T00:00:00Z"), # for first "1" + _make_inner_cursor("2024-01-02T00:00:00Z"), # for "2" + _make_inner_cursor("2024-01-03T00:00:00Z"), # for second "1" + ] + cursor_factory_mock = MagicMock() + cursor_factory_mock.create.side_effect = inner_cursors + + converter = CustomFormatConcurrentStreamStateConverter( + datetime_format="%Y-%m-%dT%H:%M:%SZ", + input_datetime_formats=["%Y-%m-%dT%H:%M:%SZ"], + is_sequential_state=True, + cursor_granularity=timedelta(0), + ) + + cursor = ConcurrentPerPartitionCursor( + cursor_factory=cursor_factory_mock, + partition_router=MagicMock(), + stream_name="dup_stream", + stream_namespace=None, + stream_state={}, + message_repository=MagicMock(), + connector_state_manager=MagicMock(), + connector_state_converter=converter, + cursor_field=CursorField(cursor_field_key="updated_at"), + ) + + # ── Partition sequence: 1 → 2 → 1 ────────────────────────────────── + partitions = [ + StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}), + StreamSlice(partition={"id": "2"}, cursor_slice={}, extra_fields={}), + StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}), + ] + pr = cursor._partition_router + pr.stream_slices.return_value = iter(partitions) + pr.get_stream_state.return_value = {} + + # Iterate lazily so that the first "1" gets cleaned before + # the second "1" arrives. 
+    slice_gen = cursor.stream_slices()
+
+    first_1 = next(slice_gen)
+    cursor.close_partition(
+        DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1)
+    )
+
+    two = next(slice_gen)
+    cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two))
+
+    # Second “1” should appear because the semaphore was cleaned up
+    second_1 = next(slice_gen)
+    cursor.close_partition(
+        DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1)
+    )
+
+    with pytest.raises(StopIteration):
+        next(slice_gen)
+
+    assert cursor._IS_PARTITION_DUPLICATION_LOGGED is False  # no duplicate warning
+    assert len(cursor._cursor_per_partition) == 2  # only “1” & “2” kept
+    assert len(cursor._semaphore_per_partition) == 0  # all semaphores cleaned
+    assert len(cursor._processing_partitions_indexes) == 0
+    assert len(cursor._partition_key_to_index) == 0
+    assert len(cursor._partitions_done_generating_stream_slices) == 0
+
+
+def test_duplicate_partition_while_processing():
+    inner_cursors = [
+        _make_inner_cursor("2024-01-01T00:00:00Z"),  # first “1”
+        _make_inner_cursor("2024-01-02T00:00:00Z"),  # “2”
+        _make_inner_cursor("2024-01-03T00:00:00Z"),  # for second "1"
+    ]
+
+    factory = MagicMock()
+    factory.create.side_effect = inner_cursors
+
+    cursor = ConcurrentPerPartitionCursor(
+        cursor_factory=factory,
+        partition_router=MagicMock(),
+        stream_name="dup_stream",
+        stream_namespace=None,
+        stream_state={},
+        message_repository=MagicMock(),
+        connector_state_manager=MagicMock(),
+        connector_state_converter=MagicMock(),
+        cursor_field=CursorField(cursor_field_key="updated_at"),
+    )
+
+    partitions = [
+        StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}),
+        StreamSlice(partition={"id": "2"}, cursor_slice={}, extra_fields={}),
+        StreamSlice(partition={"id": "1"}, cursor_slice={}, extra_fields={}),
+    ]
+    pr = cursor._partition_router
+    pr.stream_slices.return_value = iter(partitions)
+    pr.get_stream_state.return_value = {}
+
+    generated = list(cursor.stream_slices())
+    # Only “1” and “2” emitted – duplicate “1” skipped
+    assert len(generated) == 2
+
+    # Close “2” first
+    cursor.close_partition(
+        DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[1])
+    )
+    # Now close the initial “1”
+    cursor.close_partition(
+        DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[0])
+    )
+
+    assert cursor._IS_PARTITION_DUPLICATION_LOGGED is True  # warning emitted
+    assert len(cursor._cursor_per_partition) == 2
+    assert len(cursor._semaphore_per_partition) == 0
+    assert len(cursor._processing_partitions_indexes) == 0
+    assert len(cursor._partition_key_to_index) == 0
+    assert len(cursor._partitions_done_generating_stream_slices) == 0
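
Reviewer note (illustration only, not part of the patch): the parent-state bookkeeping this diff introduces reduces to a small, self-contained pattern. The sketch below is a minimal, hypothetical reduction; `ParentStateTracker`, `start_partition`, and `close_partition` are invented names, not CDK API. It shows why checking only the smallest open sequence number is enough: sequence numbers are handed out monotonically, so the open list stays sorted, and `_open_seqs[0] <= recorded_seq` holds exactly when some partition that started at or before the recorded state is still in flight.

from collections import OrderedDict
from typing import Any, Optional


class ParentStateTracker:
    """Commit a parent state only after every partition started at or before it has closed."""

    def __init__(self) -> None:
        # partition key -> (parent state, sequence number at which it was recorded)
        self._parent_state_map: OrderedDict[str, tuple[Any, int]] = OrderedDict()
        # Sequence numbers of generated-but-not-yet-closed partitions; stays
        # ascending because sequence numbers are assigned monotonically.
        self._open_seqs: list[int] = []
        self._next_seq = 0
        self.committed: Optional[Any] = None

    def start_partition(self, key: str, parent_state: Any) -> int:
        seq = self._next_seq
        self._next_seq += 1
        self._open_seqs.append(seq)
        # Only record the parent state when it differs from the latest recorded one.
        if (
            not self._parent_state_map
            or next(reversed(self._parent_state_map.values()))[0] != parent_state
        ):
            self._parent_state_map[key] = (parent_state, seq)
        return seq

    def close_partition(self, seq: int) -> None:
        self._open_seqs.remove(seq)
        # Pop recorded states from the left while no still-open partition
        # started at or before the candidate state's sequence number.
        while self._parent_state_map:
            _, (state, recorded_seq) = next(iter(self._parent_state_map.items()))
            if self._open_seqs and self._open_seqs[0] <= recorded_seq:
                break  # an earlier partition is still in flight; not safe yet
            self._parent_state_map.popitem(last=False)
            self.committed = state


tracker = ParentStateTracker()
s1 = tracker.start_partition("p1", {"parent": "A"})
s2 = tracker.start_partition("p2", {"parent": "B"})
tracker.close_partition(s2)  # "p1" (seq 0) is still open, so nothing commits
assert tracker.committed is None
tracker.close_partition(s1)  # both closed: "A" then "B" commit; "B" is final
assert tracker.committed == {"parent": "B"}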