Commit e27cb81

maxi297, aaronsteers, octavia-squidington-iii, and brianjlai authored

chore(refactor): refactor partition generator to take any stream slicer (#39)

Co-authored-by: Aaron ("AJ") Steers <[email protected]>
Co-authored-by: octavia-squidington-iii <[email protected]>
Co-authored-by: Brian Lai <[email protected]>

1 parent e808271 commit e27cb81

File tree

13 files changed: +552 -295 lines changed


airbyte_cdk/sources/declarative/concurrent_declarative_source.py

+52 -28
@@ -3,7 +3,7 @@
 #

 import logging
-from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union, Callable

 from airbyte_cdk.models import (
     AirbyteCatalog,
@@ -27,18 +27,24 @@
 )
 from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
     DatetimeBasedCursor as DatetimeBasedCursorModel,
+    DeclarativeStream as DeclarativeStreamModel,
 )
 from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
     ModelToComponentFactory,
+    ComponentDefinition,
 )
 from airbyte_cdk.sources.declarative.requesters import HttpRequester
-from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
+from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, Retriever
+from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
+    DeclarativePartitionFactory,
+    StreamSlicerPartitionGenerator,
+)
 from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
 from airbyte_cdk.sources.declarative.types import ConnectionDefinition
 from airbyte_cdk.sources.source import TState
+from airbyte_cdk.sources.types import Config, StreamState
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
-from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
 from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
     AlwaysAvailableAvailabilityStrategy,
 )
@@ -213,31 +219,18 @@ def _group_streams(
                     )
                 )

-                # This is an optimization so that we don't invoke any cursor or state management flows within the
-                # low-code framework because state management is handled through the ConcurrentCursor.
-                if (
-                    declarative_stream
-                    and declarative_stream.retriever
-                    and isinstance(declarative_stream.retriever, SimpleRetriever)
-                ):
-                    # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
-                    # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
-                    # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
-                    # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
-                    # with state.
-                    if declarative_stream.retriever.cursor:
-                        declarative_stream.retriever.cursor.set_initial_state(
-                            stream_state=stream_state
-                        )
-                    declarative_stream.retriever.cursor = None
-
-                partition_generator = CursorPartitionGenerator(
-                    stream=declarative_stream,
-                    message_repository=self.message_repository,  # type: ignore # message_repository is always instantiated with a value by factory
-                    cursor=cursor,
-                    connector_state_converter=connector_state_converter,
-                    cursor_field=[cursor.cursor_field.cursor_field_key],
-                    slice_boundary_fields=cursor.slice_boundary_fields,
+                partition_generator = StreamSlicerPartitionGenerator(
+                    DeclarativePartitionFactory(
+                        declarative_stream.name,
+                        declarative_stream.get_json_schema(),
+                        self._retriever_factory(
+                            name_to_stream_mapping[declarative_stream.name],
+                            config,
+                            stream_state,
+                        ),
+                        self.message_repository,
+                    ),
+                    cursor,
                 )

                 concurrent_streams.append(
@@ -350,3 +343,34 @@ def _remove_concurrent_streams_from_catalog(
                 if stream.stream.name not in concurrent_stream_names
             ]
         )
+
+    def _retriever_factory(
+        self, stream_config: ComponentDefinition, source_config: Config, stream_state: StreamState
+    ) -> Callable[[], Retriever]:
+        def _factory_method() -> Retriever:
+            declarative_stream: DeclarativeStream = self._constructor.create_component(
+                DeclarativeStreamModel,
+                stream_config,
+                source_config,
+                emit_connector_builder_messages=self._emit_connector_builder_messages,
+            )
+
+            # This is an optimization so that we don't invoke any cursor or state management flows within the
+            # low-code framework because state management is handled through the ConcurrentCursor.
+            if (
+                declarative_stream
+                and declarative_stream.retriever
+                and isinstance(declarative_stream.retriever, SimpleRetriever)
+            ):
+                # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
+                # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
+                # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
+                # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
+                # with state.
+                if declarative_stream.retriever.cursor:
+                    declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state)
+                declarative_stream.retriever.cursor = None
+
+            return declarative_stream.retriever
+
+        return _factory_method
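
The factory indirection above exists because low-code components are not thread safe (see the DeclarativePartitionFactory docstring in the new file below), so each partition must get its own retriever. A self-contained sketch of that idea, using hypothetical stub classes rather than the CDK's own:

from typing import Callable


class StubPaginator:
    """Stand-in for a stateful low-code component such as DefaultPaginator."""

    def __init__(self) -> None:
        self.page = 0  # mutable state that must not be shared across threads


class StubRetriever:
    """Stand-in for a SimpleRetriever-like object that owns a paginator."""

    def __init__(self) -> None:
        self.paginator = StubPaginator()


def make_retriever_factory() -> Callable[[], StubRetriever]:
    # Defer construction so that each partition (and therefore each worker
    # thread) receives its own retriever with its own mutable state.
    def _factory() -> StubRetriever:
        return StubRetriever()

    return _factory


factory = make_retriever_factory()
assert factory() is not factory()  # distinct instances: no shared paginator state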

airbyte_cdk/sources/declarative/manifest_declarative_source.py

+2 -2
@@ -8,7 +8,7 @@
 import re
 from copy import deepcopy
 from importlib import metadata
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple

 import yaml
 from airbyte_cdk.models import (
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
         return self._source_config

     @property
-    def message_repository(self) -> Union[None, MessageRepository]:
+    def message_repository(self) -> MessageRepository:
         return self._message_repository

     @property
airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py

+85 -0
@@ -0,0 +1,85 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+
+from typing import Iterable, Optional, Mapping, Any, Callable
+
+from airbyte_cdk.sources.declarative.retrievers import Retriever
+from airbyte_cdk.sources.message import MessageRepository
+from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
+from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
+from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
+from airbyte_cdk.sources.types import StreamSlice
+from airbyte_cdk.utils.slice_hasher import SliceHasher
+
+
+class DeclarativePartitionFactory:
+    def __init__(
+        self,
+        stream_name: str,
+        json_schema: Mapping[str, Any],
+        retriever_factory: Callable[[], Retriever],
+        message_repository: MessageRepository,
+    ) -> None:
+        """
+        The DeclarativePartitionFactory takes a retriever_factory and not a retriever directly. The reason is that our components are not
+        thread safe and classes like `DefaultPaginator` may not work because multiple threads can access and modify a shared field across each other.
+        In order to avoid these problems, we will create one retriever per thread which should make the processing thread-safe.
+        """
+        self._stream_name = stream_name
+        self._json_schema = json_schema
+        self._retriever_factory = retriever_factory
+        self._message_repository = message_repository
+
+    def create(self, stream_slice: StreamSlice) -> Partition:
+        return DeclarativePartition(
+            self._stream_name,
+            self._json_schema,
+            self._retriever_factory(),
+            self._message_repository,
+            stream_slice,
+        )
+
+
+class DeclarativePartition(Partition):
+    def __init__(
+        self,
+        stream_name: str,
+        json_schema: Mapping[str, Any],
+        retriever: Retriever,
+        message_repository: MessageRepository,
+        stream_slice: StreamSlice,
+    ):
+        self._stream_name = stream_name
+        self._json_schema = json_schema
+        self._retriever = retriever
+        self._message_repository = message_repository
+        self._stream_slice = stream_slice
+        self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+
+    def read(self) -> Iterable[Record]:
+        for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
+            if isinstance(stream_data, Mapping):
+                yield Record(stream_data, self)
+            else:
+                self._message_repository.emit_message(stream_data)
+
+    def to_slice(self) -> Optional[Mapping[str, Any]]:
+        return self._stream_slice
+
+    def stream_name(self) -> str:
+        return self._stream_name
+
+    def __hash__(self) -> int:
+        return self._hash
+
+
+class StreamSlicerPartitionGenerator(PartitionGenerator):
+    def __init__(
+        self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer
+    ) -> None:
+        self._partition_factory = partition_factory
+        self._stream_slicer = stream_slicer
+
+    def generate(self) -> Iterable[Partition]:
+        for stream_slice in self._stream_slicer.stream_slices():
+            yield self._partition_factory.create(stream_slice)
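
Taken together, the three classes above compose as in the following hedged sketch. The retriever factory and stream slicer are assumed inputs (in the real flow they come from the model-to-component factory and the stream's declarative definition), and InMemoryMessageRepository is used only to keep the example self-contained:

from typing import Callable

from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
    DeclarativePartitionFactory,
    StreamSlicerPartitionGenerator,
)
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer


def build_partition_generator(
    retriever_factory: Callable[[], Retriever], stream_slicer: StreamSlicer
) -> StreamSlicerPartitionGenerator:
    # One partition per slice emitted by the slicer; each partition builds its
    # own retriever through the factory, keeping per-thread state isolated.
    return StreamSlicerPartitionGenerator(
        DeclarativePartitionFactory(
            "my_stream",  # hypothetical stream name
            {"type": "object", "properties": {}},  # hypothetical JSON schema
            retriever_factory,
            InMemoryMessageRepository(),
        ),
        stream_slicer,
    )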

airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py

+6 -13
@@ -2,18 +2,17 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #

-from abc import abstractmethod
-from dataclasses import dataclass
-from typing import Iterable
+from abc import ABC

 from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
     RequestOptionsProvider,
 )
-from airbyte_cdk.sources.types import StreamSlice
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import (
+    StreamSlicer as ConcurrentStreamSlicer,
+)


-@dataclass
-class StreamSlicer(RequestOptionsProvider):
+class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
     """
     Slices the stream into a subset of records.
     Slices enable state checkpointing and data retrieval parallelization.
@@ -23,10 +22,4 @@ class StreamSlicer(RequestOptionsProvider):
     See the stream slicing section of the docs for more information.
     """

-    @abstractmethod
-    def stream_slices(self) -> Iterable[StreamSlice]:
-        """
-        Defines stream slices
-
-        :return: List of stream slices
-        """
+    pass
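
With this change, a concrete subclass satisfies both contracts at once: stream_slices() inherited from the concurrent StreamSlicer, and the request-option hooks from RequestOptionsProvider. A minimal hypothetical implementation (one empty slice, full-refresh style, no extra request options), assuming the four keyword-only request-option methods declared by RequestOptionsProvider:

from typing import Any, Iterable, Mapping, Optional

from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice, StreamState


class SingleSliceSlicer(StreamSlicer):
    """Hypothetical slicer: emits one empty slice and adds no request options."""

    def stream_slices(self) -> Iterable[StreamSlice]:
        # Satisfies the concurrent StreamSlicer contract.
        yield StreamSlice(partition={}, cursor_slice={})

    def get_request_params(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        return {}

    def get_request_headers(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        return {}

    def get_request_body_data(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        return {}

    def get_request_body_json(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        return {}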

airbyte_cdk/sources/streams/concurrent/adapters.py

+4 -87
@@ -38,15 +38,13 @@
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
 from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
-from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
-    DateTimeStreamStateConverter,
-)
 from airbyte_cdk.sources.streams.core import StreamData
-from airbyte_cdk.sources.types import StreamSlice
 from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
 from airbyte_cdk.sources.utils.slice_logger import SliceLogger
 from deprecated.classic import deprecated

+from airbyte_cdk.utils.slice_hasher import SliceHasher
+
 """
 This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream
 """
@@ -270,6 +268,7 @@ def __init__(
         self._sync_mode = sync_mode
         self._cursor_field = cursor_field
         self._state = state
+        self._hash = SliceHasher.hash(self._stream.name, self._slice)

     def read(self) -> Iterable[Record]:
         """
@@ -309,12 +308,7 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._slice

     def __hash__(self) -> int:
-        if self._slice:
-            # Convert the slice to a string so that it can be hashed
-            s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder)
-            return hash((self._stream.name, s))
-        else:
-            return hash(self._stream.name)
+        return self._hash

     def stream_name(self) -> str:
         return self._stream.name
@@ -363,83 +357,6 @@ def generate(self) -> Iterable[Partition]:
         )


-class CursorPartitionGenerator(PartitionGenerator):
-    """
-    This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions.
-
-    It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained
-    across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state.
-    """
-
-    _START_BOUNDARY = 0
-    _END_BOUNDARY = 1
-
-    def __init__(
-        self,
-        stream: Stream,
-        message_repository: MessageRepository,
-        cursor: Cursor,
-        connector_state_converter: DateTimeStreamStateConverter,
-        cursor_field: Optional[List[str]],
-        slice_boundary_fields: Optional[Tuple[str, str]],
-    ):
-        """
-        Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
-
-        :param stream: The stream to delegate to for partition generation.
-        :param message_repository: The message repository to use to emit non-record messages.
-        :param sync_mode: The synchronization mode.
-        :param cursor: A Cursor object that maintains the state and the cursor field.
-        """
-        self._stream = stream
-        self.message_repository = message_repository
-        self._sync_mode = SyncMode.full_refresh
-        self._cursor = cursor
-        self._cursor_field = cursor_field
-        self._state = self._cursor.state
-        self._slice_boundary_fields = slice_boundary_fields
-        self._connector_state_converter = connector_state_converter
-
-    def generate(self) -> Iterable[Partition]:
-        """
-        Generate partitions based on the slices in the cursor's state.
-
-        This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
-        a `StreamPartition` object.
-
-        :return: An iterable of StreamPartition objects.
-        """
-
-        start_boundary = (
-            self._slice_boundary_fields[self._START_BOUNDARY]
-            if self._slice_boundary_fields
-            else "start"
-        )
-        end_boundary = (
-            self._slice_boundary_fields[self._END_BOUNDARY]
-            if self._slice_boundary_fields
-            else "end"
-        )
-
-        for slice_start, slice_end in self._cursor.generate_slices():
-            stream_slice = StreamSlice(
-                partition={},
-                cursor_slice={
-                    start_boundary: self._connector_state_converter.output_format(slice_start),
-                    end_boundary: self._connector_state_converter.output_format(slice_end),
-                },
-            )
-
-            yield StreamPartition(
-                self._stream,
-                copy.deepcopy(stream_slice),
-                self.message_repository,
-                self._sync_mode,
-                self._cursor_field,
-                self._state,
-            )
-
-
 @deprecated(
     "Availability strategy has been soft deprecated. Do not use. Class is subject to removal",
     category=ExperimentalClassWarning,
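
The __hash__ change above is a precompute-once design: instead of JSON-serializing the slice on every call, the hash is computed in __init__ through the shared SliceHasher (the same helper the new DeclarativePartition uses), since partitions are hashed repeatedly while the concurrent framework tracks them in sets and dicts. A standalone sketch of that trade-off, independent of the CDK:

import json
from typing import Any, Mapping, Optional


class PrecomputedHashPartition:
    def __init__(self, stream_name: str, slice_: Optional[Mapping[str, Any]]) -> None:
        # Hash once at construction; __hash__ then becomes a cheap field read.
        if slice_:
            serialized = json.dumps(slice_, sort_keys=True)
            self._hash = hash((stream_name, serialized))
        else:
            self._hash = hash(stream_name)

    def __hash__(self) -> int:
        return self._hash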
