diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 92f4bdc4b..d4ecc0084 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -475,10 +475,21 @@ def _get_retriever(
             # Also a temporary hack. In the legacy Stream implementation, as part of the read,
             # set_initial_state() is called to instantiate incoming state on the cursor. Although we no
             # longer rely on the legacy low-code cursor for concurrent checkpointing, low-code components
-            # like StopConditionPaginationStrategyDecorator and ClientSideIncrementalRecordFilterDecorator
-            # still rely on a DatetimeBasedCursor that is properly initialized with state.
+            # like StopConditionPaginationStrategyDecorator still rely on a DatetimeBasedCursor that is
+            # properly initialized with state.
             if retriever.cursor:
                 retriever.cursor.set_initial_state(stream_state=stream_state)
+
+            # Similar to above, the ClientSideIncrementalRecordFilterDecorator cursor is a separate instance
+            # from the one initialized on the SimpleRetriever, so it must also have state initialized
+            # for semi-incremental streams using is_client_side_incremental to filter properly
+            if isinstance(retriever.record_selector, RecordSelector) and isinstance(
+                retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
+            ):
+                retriever.record_selector.record_filter._cursor.set_initial_state(
+                    stream_state=stream_state
+                )  # type: ignore  # After non-concurrent cursors are deprecated we can remove these cursor workarounds
+
             # We zero it out here, but since this is a cursor reference, the state is still properly
             # instantiated for the other components that reference it
             retriever.cursor = None
diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py
index 892876850..1877e11bb 100644
--- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py
+++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py
@@ -32,6 +32,9 @@
     ConcurrentDeclarativeSource,
 )
 from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
+from airbyte_cdk.sources.declarative.extractors.record_filter import (
+    ClientSideIncrementalRecordFilterDecorator,
+)
 from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
 from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
     StreamSlicerPartitionGenerator,
@@ -1647,6 +1650,44 @@ def test_async_incremental_stream_uses_concurrent_cursor_with_state():
     assert async_job_partition_router.stream_slicer._concurrent_state == expected_state
 
 
+def test_stream_using_is_client_side_incremental_has_cursor_state():
+    expected_cursor_value = "2024-07-01"
+    state = [
+        AirbyteStateMessage(
+            type=AirbyteStateType.STREAM,
+            stream=AirbyteStreamState(
+                stream_descriptor=StreamDescriptor(name="locations", namespace=None),
+                stream_state=AirbyteStateBlob(updated_at=expected_cursor_value),
+            ),
+        )
+    ]
+
+    manifest_with_stream_state_interpolation = copy.deepcopy(_MANIFEST)
+
+    # Enable semi-incremental on the locations stream
+    manifest_with_stream_state_interpolation["definitions"]["locations_stream"]["incremental_sync"][
+        "is_client_side_incremental"
+    ] = True
+
+    source = ConcurrentDeclarativeSource(
+        source_config=manifest_with_stream_state_interpolation,
+        config=_CONFIG,
+        catalog=_CATALOG,
+        state=state,
+    )
+    concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)
+
+    locations_stream = concurrent_streams[2]
+    assert isinstance(locations_stream, DefaultStream)
+
+    simple_retriever = locations_stream._stream_partition_generator._partition_factory._retriever
+    record_filter = simple_retriever.record_selector.record_filter
+    assert isinstance(record_filter, ClientSideIncrementalRecordFilterDecorator)
+    client_side_incremental_cursor_state = record_filter._cursor._cursor
+
+    assert client_side_incremental_cursor_state == expected_cursor_value
+
+
 def create_wrapped_stream(stream: DeclarativeStream) -> Stream:
     slice_to_records_mapping = get_mocked_read_records_output(stream_name=stream.name)
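
Below is a minimal, self-contained sketch (not CDK code) of the semi-incremental pattern the patch above fixes: the record filter holds its own cursor, and unless that cursor is seeded from persisted stream state via set_initial_state(), every record passes the client-side filter. ToyCursor and ToyClientSideFilter are hypothetical stand-ins for DatetimeBasedCursor and ClientSideIncrementalRecordFilterDecorator, simplified to string comparison on the cursor field.

from typing import Any, Iterable, Mapping, Optional


class ToyCursor:
    """Hypothetical stand-in for DatetimeBasedCursor, reduced to string comparison."""

    def __init__(self, cursor_field: str) -> None:
        self.cursor_field = cursor_field
        self._cursor: Optional[str] = None

    def set_initial_state(self, stream_state: Mapping[str, Any]) -> None:
        # Seed the cursor from persisted state, as the patch does for the filter's cursor.
        self._cursor = stream_state.get(self.cursor_field)

    def should_be_synced(self, record: Mapping[str, Any]) -> bool:
        # If set_initial_state() is never called, _cursor stays None and nothing is filtered,
        # which is the behavior the patch guards against for is_client_side_incremental streams.
        return self._cursor is None or record[self.cursor_field] >= self._cursor


class ToyClientSideFilter:
    """Hypothetical stand-in for ClientSideIncrementalRecordFilterDecorator."""

    def __init__(self, cursor: ToyCursor) -> None:
        self._cursor = cursor

    def filter_records(self, records: Iterable[Mapping[str, Any]]) -> Iterable[Mapping[str, Any]]:
        return (record for record in records if self._cursor.should_be_synced(record))


if __name__ == "__main__":
    cursor = ToyCursor(cursor_field="updated_at")
    cursor.set_initial_state({"updated_at": "2024-07-01"})
    records = [{"updated_at": "2024-06-15"}, {"updated_at": "2024-07-02"}]
    # Only the 2024-07-02 record survives the client-side filter.
    print(list(ToyClientSideFilter(cursor).filter_records(records)))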