From ee537afe0011c01d5124f1c0c556a8b5ff8ad70e Mon Sep 17 00:00:00 2001
From: Artem Inzhyyants <36314070+artem1205@users.noreply.github.com>
Date: Thu, 30 Jan 2025 21:40:16 +0100
Subject: [PATCH] feat: use create_concurrent_cursor_from_perpartition_cursor (#286)

Signed-off-by: Artem Inzhyyants
---
 .../declarative/async_job/job_orchestrator.py |  8 +++---
 .../concurrent_declarative_source.py          |  3 ++-
 .../sources/declarative/declarative_stream.py |  4 ++-
 .../parsers/model_to_component_factory.py     | 27 ++++++++++++++++++-
 .../async_job_partition_router.py             | 10 +++----
 .../declarative/retrievers/async_retriever.py | 18 +++++--------
 .../async_job/test_job_orchestrator.py        |  3 +--
 .../test_async_job_partition_router.py        | 20 ++++++--------
 8 files changed, 55 insertions(+), 38 deletions(-)

diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
index 3938b8c07..398cee9ff 100644
--- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
+++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -482,16 +482,16 @@ def _is_breaking_exception(self, exception: Exception) -> bool:
             and exception.failure_type == FailureType.config_error
         )
 
-    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+    def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
         """
-        Fetches records from the given partition's jobs.
+        Fetches records from the given jobs.
 
         Args:
-            partition (AsyncPartition): The partition containing the jobs.
+            async_jobs (Iterable[AsyncJob]): The jobs to fetch records from.
 
         Yields:
             Iterable[Mapping[str, Any]]: The fetched records from the jobs.
         """
-        for job in partition.jobs:
+        for job in async_jobs:
             yield from self._job_repository.fetch_records(job)
             self._job_repository.delete(job)
 
diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 3293731fd..92f4bdc4b 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -19,6 +19,7 @@
 from airbyte_cdk.sources.declarative.extractors.record_filter import (
     ClientSideIncrementalRecordFilterDecorator,
 )
+from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
 from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
 from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import (
     PerPartitionWithGlobalCursor,
 )
@@ -231,7 +232,7 @@ def _group_streams(
                 ):
                     cursor = declarative_stream.retriever.stream_slicer.stream_slicer
 
-                    if not isinstance(cursor, ConcurrentCursor):
+                    if not isinstance(cursor, ConcurrentCursor | ConcurrentPerPartitionCursor):
                         # This should never happen since we instantiate ConcurrentCursor in
                         # model_to_component_factory.py
                         raise ValueError(
diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py
index 12cdd3337..f7b97f3b4 100644
--- a/airbyte_cdk/sources/declarative/declarative_stream.py
+++ b/airbyte_cdk/sources/declarative/declarative_stream.py
@@ -138,7 +138,9 @@ def read_records(
         """
         :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
""" - if stream_slice is None or stream_slice == {}: + if stream_slice is None or ( + not isinstance(stream_slice, StreamSlice) and stream_slice == {} + ): # As the parameter is Optional, many would just call `read_records(sync_mode)` during testing without specifying the field # As part of the declarative model without custom components, this should never happen as the CDK would wire up a # SinglePartitionRouter that would create this StreamSlice properly diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index a8736986e..b8eeca1ec 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1656,7 +1656,7 @@ def _build_stream_slicer_from_partition_router( ) -> Optional[PartitionRouter]: if ( hasattr(model, "partition_router") - and isinstance(model, SimpleRetrieverModel) + and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel) and model.partition_router ): stream_slicer_model = model.partition_router @@ -1690,6 +1690,31 @@ def _merge_stream_slicers( stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) if model.incremental_sync and stream_slicer: + if model.retriever.type == "AsyncRetriever": + if model.incremental_sync.type != "DatetimeBasedCursor": + # We are currently in a transition to the Concurrent CDK and AsyncRetriever can only work with the support or unordered slices (for example, when we trigger reports for January and February, the report in February can be completed first). Once we have support for custom concurrent cursor or have a new implementation available in the CDK, we can enable more cursors here. + raise ValueError( + "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet" + ) + if stream_slicer: + return self.create_concurrent_cursor_from_perpartition_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + state_manager=self._connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + stream_state={}, + partition_router=stream_slicer, + ) + return self.create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. 
However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + ) + incremental_sync_model = model.incremental_sync if ( hasattr(incremental_sync_model, "global_substream_cursor") diff --git a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py index 0f11820f7..38a4f5328 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py @@ -4,9 +4,9 @@ from typing import Any, Callable, Iterable, Mapping, Optional from airbyte_cdk.models import FailureType +from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( AsyncJobOrchestrator, - AsyncPartition, ) from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( SinglePartitionRouter, @@ -42,12 +42,12 @@ def stream_slices(self) -> Iterable[StreamSlice]: for completed_partition in self._job_orchestrator.create_and_get_completed_partitions(): yield StreamSlice( - partition=dict(completed_partition.stream_slice.partition) - | {"partition": completed_partition}, + partition=dict(completed_partition.stream_slice.partition), cursor_slice=completed_partition.stream_slice.cursor_slice, + extra_fields={"jobs": list(completed_partition.jobs)}, ) - def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: + def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]: """ This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should be responsible for. 
         However, this was added in because the JobOrchestrator is required to
@@ -62,4 +62,4 @@ def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]
                 failure_type=FailureType.system_error,
             )
 
-        return self._job_orchestrator.fetch_records(partition=partition)
+        return self._job_orchestrator.fetch_records(async_jobs=async_jobs)
diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
index bd28e0e2d..24f52cfd3 100644
--- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
+++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
@@ -6,7 +6,7 @@
 
 from typing_extensions import deprecated
 
-from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
 from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
@@ -16,7 +16,6 @@
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
 from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
-from airbyte_cdk.utils.traced_exception import AirbyteTracedException
 
 
 @deprecated(
@@ -57,9 +56,9 @@ def _get_stream_state(self) -> StreamState:
         return self.state
 
-    def _validate_and_get_stream_slice_partition(
+    def _validate_and_get_stream_slice_jobs(
         self, stream_slice: Optional[StreamSlice] = None
-    ) -> AsyncPartition:
+    ) -> Iterable[AsyncJob]:
         """
         Validates the stream_slice argument and returns the partition from it.
 
@@ -73,12 +72,7 @@ def _validate_and_get_stream_slice_partition(
             AirbyteTracedException: If the stream_slice is not an instance of StreamSlice or if the partition is not present in the stream_slice.
         """
-        if not isinstance(stream_slice, StreamSlice) or "partition" not in stream_slice.partition:
-            raise AirbyteTracedException(
-                message="Invalid arguments to AsyncRetriever.read_records: stream_slice is not optional. Please contact Airbyte Support",
-                failure_type=FailureType.system_error,
-            )
-        return stream_slice["partition"]  # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
+        return stream_slice.extra_fields.get("jobs", []) if stream_slice else []
 
     def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
         return self.stream_slicer.stream_slices()
 
@@ -89,8 +83,8 @@ def read_records(
         stream_slice: Optional[StreamSlice] = None,
     ) -> Iterable[StreamData]:
         stream_state: StreamState = self._get_stream_state()
-        partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
-        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
+        jobs: Iterable[AsyncJob] = self._validate_and_get_stream_slice_jobs(stream_slice)
+        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(jobs)
 
         yield from self.record_selector.filter_and_transform(
             all_data=records,
diff --git a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
index d2fb9018f..dc81eacbc 100644
--- a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
+++ b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
@@ -174,9 +174,8 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None:
         orchestrator = self._orchestrator([_A_STREAM_SLICE])
         first_job = _create_job()
         second_job = _create_job()
-        partition = AsyncPartition([first_job, second_job], _A_STREAM_SLICE)
 
-        records = list(orchestrator.fetch_records(partition))
+        records = list(orchestrator.fetch_records([first_job, second_job]))
 
         assert len(records) == 2
         assert self._job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)]
diff --git a/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py
index ccc57cc91..2a5ac3277 100644
--- a/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py
+++ b/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py
@@ -35,12 +35,12 @@ def test_stream_slices_with_single_partition_router():
     slices = list(partition_router.stream_slices())
     assert len(slices) == 1
-    partition = slices[0].partition.get("partition")
-    assert isinstance(partition, AsyncPartition)
-    assert partition.stream_slice == StreamSlice(partition={}, cursor_slice={})
-    assert partition.status == AsyncJobStatus.COMPLETED
+    partition = slices[0]
+    assert isinstance(partition, StreamSlice)
+    assert partition == StreamSlice(partition={}, cursor_slice={})
+    assert partition.extra_fields["jobs"][0].status() == AsyncJobStatus.COMPLETED
 
-    attempts_per_job = list(partition.jobs)
+    attempts_per_job = list(partition.extra_fields["jobs"])
     assert len(attempts_per_job) == 1
     assert attempts_per_job[0].api_job_id() == "a_job_id"
     assert attempts_per_job[0].job_parameters() == StreamSlice(partition={}, cursor_slice={})
@@ -68,14 +68,10 @@ def test_stream_slices_with_parent_slicer():
     slices = list(partition_router.stream_slices())
     assert len(slices) == 3
     for i, partition in enumerate(slices):
-        partition = partition.partition.get("partition")
-        assert isinstance(partition, AsyncPartition)
-        assert partition.stream_slice == StreamSlice(
-            partition={"parent_id": str(i)}, cursor_slice={}
-        )
-        assert partition.status == AsyncJobStatus.COMPLETED
+        assert isinstance(partition, StreamSlice)
+        assert partition == StreamSlice(partition={"parent_id": str(i)}, cursor_slice={})
 
-        attempts_per_job = list(partition.jobs)
+        attempts_per_job = list(partition.extra_fields["jobs"])
         assert len(attempts_per_job) == 1
         assert attempts_per_job[0].api_job_id() == "a_job_id"
         assert attempts_per_job[0].job_parameters() == StreamSlice(
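
Reviewer note (not part of the patch): the core contract change is that completed AsyncJobs no longer ride inside the slice's partition dict, wrapped in an AsyncPartition under the "partition" key; they now travel out-of-band in StreamSlice.extra_fields under "jobs". A minimal sketch of the new contract, using the public StreamSlice type from airbyte_cdk.sources.types (`completed_jobs` is an illustrative placeholder for AsyncJob instances produced by the orchestrator):

    from airbyte_cdk.sources.types import StreamSlice

    completed_jobs = [...]  # placeholder: completed AsyncJob instances from the AsyncJobOrchestrator

    # What AsyncJobPartitionRouter.stream_slices() now emits:
    slice_ = StreamSlice(
        partition={"parent_id": "0"},           # only real partition values, no embedded AsyncPartition
        cursor_slice={},
        extra_fields={"jobs": completed_jobs},  # jobs ride alongside the slice
    )

    # How AsyncRetriever._validate_and_get_stream_slice_jobs() reads them back:
    jobs = slice_.extra_fields.get("jobs", [])

Because extra_fields does not participate in slice equality (the updated tests assert `partition == StreamSlice(partition={}, cursor_slice={})` even though the slice carries jobs), slices now compare on partition and cursor values alone, which is what allows create_concurrent_cursor_from_perpartition_cursor to manage per-partition state for AsyncRetriever streams.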