
Commit 57e1b52

fix(Low-Code Concurrent CDK): Refactor the low-code AsyncRetriever to use an underlying StreamSlicer (#170)
1 parent 9563c33 commit 57e1b52

File tree

7 files changed: +295, -45 lines changed


airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -329,6 +329,9 @@
     SinglePartitionRouter,
     SubstreamPartitionRouter,
 )
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
     ParentStreamConfig,
 )
@@ -2260,22 +2263,28 @@ def create_async_retriever(
             urls_extractor=urls_extractor,
         )
 
-        return AsyncRetriever(
+        async_job_partition_router = AsyncJobPartitionRouter(
             job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
                 job_repository,
                 stream_slices,
-                JobTracker(
-                    1
-                ),  # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
+                JobTracker(1),
+                # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
                 self._message_repository,
-                has_bulk_parent=False,  # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
+                has_bulk_parent=False,
+                # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
             ),
-            record_selector=record_selector,
             stream_slicer=stream_slicer,
             config=config,
             parameters=model.parameters or {},
         )
 
+        return AsyncRetriever(
+            record_selector=record_selector,
+            stream_slicer=async_job_partition_router,
+            config=config,
+            parameters=model.parameters or {},
+        )
+
     @staticmethod
     def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
         return Spec(
```
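Taken together, the factory change above re-homes the orchestration wiring: the declarative stream_slicer and the job-orchestrator factory now live on AsyncJobPartitionRouter, and AsyncRetriever only receives the router through its generic stream_slicer field. A minimal sketch of that composition, assuming hypothetical stand-ins (job_repository, message_repository, record_selector, inner_slicer, config) and import paths inferred from the diffs in this commit:

```python
# Sketch of the new composition; job_repository, message_repository,
# record_selector, inner_slicer, and config are hypothetical stand-ins
# assumed to be built elsewhere.
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker
from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever

router = AsyncJobPartitionRouter(
    stream_slicer=inner_slicer,  # the declarative slicer that used to be passed to AsyncRetriever
    job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
        job_repository,
        stream_slices,
        JobTracker(1),  # concurrency still capped at 1, per the FIXME above
        message_repository,
        has_bulk_parent=False,
    ),
    config=config,
    parameters={},
)

retriever = AsyncRetriever(
    record_selector=record_selector,
    stream_slicer=router,  # the retriever no longer sees the orchestrator directly
    config=config,
    parameters={},
)
```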

airbyte_cdk/sources/declarative/partition_routers/__init__.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -2,10 +2,18 @@
 # Copyright (c) 2022 Airbyte, Inc., all rights reserved.
 #
 
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import AsyncJobPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import CartesianProductStreamSlicer
 from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
 
-__all__ = ["CartesianProductStreamSlicer", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", "PartitionRouter"]
+__all__ = [
+    "AsyncJobPartitionRouter",
+    "CartesianProductStreamSlicer",
+    "ListPartitionRouter",
+    "SinglePartitionRouter",
+    "SubstreamPartitionRouter",
+    "PartitionRouter"
+]
```
airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py

Lines changed: 65 additions & 0 deletions

```diff
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+
+from dataclasses import InitVar, dataclass, field
+from typing import Any, Callable, Iterable, Mapping, Optional
+
+from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
+    AsyncJobOrchestrator,
+    AsyncPartition,
+)
+from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
+    SinglePartitionRouter,
+)
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
+from airbyte_cdk.sources.types import Config, StreamSlice
+from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
+
+@dataclass
+class AsyncJobPartitionRouter(StreamSlicer):
+    """
+    Partition router that creates async jobs in a source API, periodically polls for job
+    completion, and supplies the completed job URL locations as stream slices so that
+    records can be extracted.
+    """
+
+    config: Config
+    parameters: InitVar[Mapping[str, Any]]
+    job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
+    stream_slicer: StreamSlicer = field(
+        default_factory=lambda: SinglePartitionRouter(parameters={})
+    )
+
+    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
+        self._job_orchestrator_factory = self.job_orchestrator_factory
+        self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
+        self._parameters = parameters
+
+    def stream_slices(self) -> Iterable[StreamSlice]:
+        slices = self.stream_slicer.stream_slices()
+        self._job_orchestrator = self._job_orchestrator_factory(slices)
+
+        for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
+            yield StreamSlice(
+                partition=dict(completed_partition.stream_slice.partition)
+                | {"partition": completed_partition},
+                cursor_slice=completed_partition.stream_slice.cursor_slice,
+            )
+
+    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+        """
+        This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
+        be responsible for. However, this was added in because the JobOrchestrator is required to
+        retrieve records. And without defining fetch_records() on this class, we're stuck with either
+        passing the JobOrchestrator to the AsyncRetriever or storing it on multiple classes.
+        """
+
+        if not self._job_orchestrator:
+            raise AirbyteTracedException(
+                message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
+                internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
+                failure_type=FailureType.system_error,
+            )
+
+        return self._job_orchestrator.fetch_records(partition=partition)
```
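Since the docstrings above describe the contract indirectly, here is a short, hypothetical stand-alone usage sketch; my_job_repository and my_message_repository are assumed stand-ins for real implementations, and, per the guard above, fetch_records() is only valid after stream_slices() has created the orchestrator:

```python
# Hypothetical stand-alone usage of AsyncJobPartitionRouter.
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker
from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter

router = AsyncJobPartitionRouter(
    job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
        my_job_repository,     # assumed AsyncJobRepository implementation
        stream_slices,
        JobTracker(1),         # at most one concurrent job
        my_message_repository, # assumed MessageRepository instance
    ),
    config={},
    parameters={},  # stream_slicer defaults to SinglePartitionRouter
)

# stream_slices() starts the async jobs and yields one slice per completed
# partition; the AsyncPartition rides along under the "partition" key.
for stream_slice in router.stream_slices():
    records = router.fetch_records(partition=stream_slice["partition"])
```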

airbyte_cdk/sources/declarative/retrievers/async_retriever.py

Lines changed: 8 additions & 31 deletions
```diff
@@ -1,8 +1,8 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
 
 
-from dataclasses import InitVar, dataclass, field
-from typing import Any, Callable, Iterable, Mapping, Optional
+from dataclasses import InitVar, dataclass
+from typing import Any, Iterable, Mapping, Optional
 
 from typing_extensions import deprecated
 
@@ -12,9 +12,10 @@
     AsyncPartition,
 )
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
-from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.retrievers import Retriever
-from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
 from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
@@ -29,15 +30,10 @@
 class AsyncRetriever(Retriever):
     config: Config
     parameters: InitVar[Mapping[str, Any]]
-    job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
     record_selector: RecordSelector
-    stream_slicer: StreamSlicer = field(
-        default_factory=lambda: SinglePartitionRouter(parameters={})
-    )
+    stream_slicer: AsyncJobPartitionRouter
 
     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
-        self._job_orchestrator_factory = self.job_orchestrator_factory
-        self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None
         self._parameters = parameters
 
     @property
@@ -54,17 +50,6 @@ def state(self, value: StreamState) -> None:
         """
         pass
 
-    @property
-    def _job_orchestrator(self) -> AsyncJobOrchestrator:
-        if not self.__job_orchestrator:
-            raise AirbyteTracedException(
-                message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
-                internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
-                failure_type=FailureType.system_error,
-            )
-
-        return self.__job_orchestrator
-
     def _get_stream_state(self) -> StreamState:
         """
         Gets the current state of the stream.
@@ -99,15 +84,7 @@ def _validate_and_get_stream_slice_partition(
         return stream_slice["partition"]  # type: ignore  # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
 
     def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
-        slices = self.stream_slicer.stream_slices()
-        self.__job_orchestrator = self._job_orchestrator_factory(slices)
-
-        for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
-            yield StreamSlice(
-                partition=dict(completed_partition.stream_slice.partition)
-                | {"partition": completed_partition},
-                cursor_slice=completed_partition.stream_slice.cursor_slice,
-            )
+        return self.stream_slicer.stream_slices()
 
     def read_records(
         self,
@@ -116,7 +93,7 @@ def read_records(
     ) -> Iterable[StreamData]:
         stream_state: StreamState = self._get_stream_state()
         partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
-        records: Iterable[Mapping[str, Any]] = self._job_orchestrator.fetch_records(partition)
+        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
 
         yield from self.record_selector.filter_and_transform(
             all_data=records,
```
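The net effect on the read path can be sketched as below; retriever is the composed object from the earlier sketch, json_schema is a hypothetical placeholder, and the read_records() parameter names are assumed from the Retriever interface:

```python
# Hypothetical driver loop over the refactored AsyncRetriever: stream_slices()
# now just delegates to the router, and read_records() resolves the slice's
# AsyncPartition and fetches its records via the router's fetch_records().
for stream_slice in retriever.stream_slices():
    for record in retriever.read_records(records_schema=json_schema, stream_slice=stream_slice):
        print(record)
```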

unit_tests/sources/declarative/async_job/test_integration.py

Lines changed: 14 additions & 7 deletions
```diff
@@ -20,6 +20,9 @@
 from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
 from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
 from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
@@ -35,7 +38,7 @@
 
 class MockAsyncJobRepository(AsyncJobRepository):
     def start(self, stream_slice: StreamSlice) -> AsyncJob:
-        return AsyncJob("a_job_id", StreamSlice(partition={}, cursor_slice={}))
+        return AsyncJob("a_job_id", stream_slice)
 
     def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
         for job in jobs:
@@ -79,12 +82,16 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
                 config={},
                 parameters={},
                 record_selector=noop_record_selector,
-                stream_slicer=self._stream_slicer,
-                job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
-                    MockAsyncJobRepository(),
-                    stream_slices,
-                    JobTracker(_NO_LIMIT),
-                    self._message_repository,
+                stream_slicer=AsyncJobPartitionRouter(
+                    stream_slicer=self._stream_slicer,
+                    job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
+                        MockAsyncJobRepository(),
+                        stream_slices,
+                        JobTracker(_NO_LIMIT),
+                        self._message_repository,
+                    ),
+                    config={},
+                    parameters={},
                 ),
             ),
             config={},
```
