Skip to content

fix(concurrent-cdk): Move the grouping of concurrent and synchronous streams into the read and discover commands instead of when initializing the source #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,10 @@ def __init__(
component_factory=component_factory,
)

# todo: We could remove state from initialization. Now that streams are grouped during the read(), a source
# no longer needs to store the original incoming state. But maybe there's an edge case?
self._state = state

self._concurrent_streams: Optional[List[AbstractStream]]
self._synchronous_streams: Optional[List[Stream]]

# If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because
# they might depend on it. Ideally we want to have a static method on this class to get the spec without
# any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this
# for our future improvements to the CDK.
if config:
self._concurrent_streams, self._synchronous_streams = self._group_streams(
config=config or {}
)
else:
self._concurrent_streams = None
self._synchronous_streams = None

concurrency_level_from_manifest = self._source_config.get("concurrency_level")
if concurrency_level_from_manifest:
concurrency_level_component = self._constructor.create_component(
Expand Down Expand Up @@ -136,17 +123,20 @@ def read(
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: Optional[Union[List[AirbyteStateMessage]]] = None,
state: Optional[List[AirbyteStateMessage]] = None,
) -> Iterator[AirbyteMessage]:
# ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent
# streams must be saved so that they can be removed from the catalog before starting synchronous streams
if self._concurrent_streams:
concurrent_streams, _ = self._group_streams(config=config)

# ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of
# the concurrent streams must be saved so that they can be removed from the catalog before starting
# synchronous streams
if len(concurrent_streams) > 0:
concurrent_stream_names = set(
[concurrent_stream.name for concurrent_stream in self._concurrent_streams]
[concurrent_stream.name for concurrent_stream in concurrent_streams]
)

selected_concurrent_streams = self._select_streams(
streams=self._concurrent_streams, configured_catalog=catalog
streams=concurrent_streams, configured_catalog=catalog
)
# It would appear that passing in an empty set of streams causes an infinite loop in ConcurrentReadProcessor.
# This is also evident in concurrent_source_adapter.py so I'll leave this out of scope to fix for now
Expand All @@ -165,8 +155,7 @@ def read(
yield from super().read(logger, config, filtered_catalog, state)

def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
concurrent_streams = self._concurrent_streams or []
synchronous_streams = self._synchronous_streams or []
concurrent_streams, synchronous_streams = self._group_streams(config=config)
return AirbyteCatalog(
streams=[
stream.as_airbyte_stream() for stream in concurrent_streams + synchronous_streams
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_jsonl_decoder(requests_mock, response_body, expected_json):
def large_event_response_fixture():
data = {"email": "[email protected]"}
jsonl_string = f"{json.dumps(data)}\n"
lines_in_response = 2 # ≈ 58 MB of response
lines_in_response = 2_000_000 # ≈ 58 MB of response
dir_path = os.path.dirname(os.path.realpath(__file__))
file_path = f"{dir_path}/test_response.txt"
with open(file_path, "w") as file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from unittest.mock import patch

import freezegun
import isodate
Expand Down Expand Up @@ -647,8 +648,7 @@ def test_group_streams():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=state
)
concurrent_streams = source._concurrent_streams
synchronous_streams = source._synchronous_streams
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

# 1 full refresh stream, 2 incremental streams, 1 substream w/o incremental, 1 list based substream w/o incremental
assert len(concurrent_streams) == 5
Expand Down Expand Up @@ -705,8 +705,9 @@ def test_create_concurrent_cursor():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state
)
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

party_members_stream = source._concurrent_streams[0]
party_members_stream = concurrent_streams[0]
assert isinstance(party_members_stream, DefaultStream)
party_members_cursor = party_members_stream.cursor

Expand All @@ -722,7 +723,7 @@ def test_create_concurrent_cursor():
assert party_members_cursor._lookback_window == timedelta(days=5)
assert party_members_cursor._cursor_granularity == timedelta(days=1)

locations_stream = source._concurrent_streams[2]
locations_stream = concurrent_streams[2]
assert isinstance(locations_stream, DefaultStream)
locations_cursor = locations_stream.cursor

Expand Down Expand Up @@ -866,7 +867,21 @@ def _mock_party_members_skills_requests(http_mocker: HttpMocker) -> None:
)


def mocked_init(self, is_sequential_state: bool = True):
"""
This method is used to patch the existing __init__() function and always set is_sequential_state to
false. This is required because we want to test the concurrent state format. And because streams are
created under the hood of the read/discover/check command, we have no way of setting the field without
patching __init__()
"""
self._is_sequential_state = False


@freezegun.freeze_time(_NOW)
@patch(
"airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",
mocked_init,
)
def test_read_with_concurrent_and_synchronous_streams():
"""
Verifies that a ConcurrentDeclarativeSource processes concurrent streams followed by synchronous streams
Expand All @@ -879,7 +894,6 @@ def test_read_with_concurrent_and_synchronous_streams():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=None
)
disable_emitting_sequential_state_messages(source=source)

with HttpMocker() as http_mocker:
_mock_party_members_requests(http_mocker, _NO_STATE_PARTY_MEMBERS_SLICES_AND_RESPONSES)
Expand Down Expand Up @@ -959,6 +973,10 @@ def test_read_with_concurrent_and_synchronous_streams():


@freezegun.freeze_time(_NOW)
@patch(
"airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",
mocked_init,
)
def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state():
"""
Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming
Expand Down Expand Up @@ -1016,7 +1034,6 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state
)
disable_emitting_sequential_state_messages(source=source)

with HttpMocker() as http_mocker:
_mock_party_members_requests(http_mocker, party_members_slices_and_responses)
Expand Down Expand Up @@ -1080,6 +1097,10 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state():


@freezegun.freeze_time(_NOW)
@patch(
"airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",
mocked_init,
)
def test_read_with_concurrent_and_synchronous_streams_with_sequential_state():
"""
Verifies that a ConcurrentDeclarativeSource processes concurrent streams correctly using the incoming
Expand All @@ -1105,7 +1126,6 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state
)
disable_emitting_sequential_state_messages(source=source)

party_members_slices_and_responses = _NO_STATE_PARTY_MEMBERS_SLICES_AND_RESPONSES + [
(
Expand Down Expand Up @@ -1204,6 +1224,10 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state():


@freezegun.freeze_time(_NOW)
@patch(
"airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",
mocked_init,
)
def test_read_concurrent_with_failing_partition_in_the_middle():
"""
Verify that partial state is emitted when only some partitions are successful during a concurrent sync attempt
Expand Down Expand Up @@ -1236,7 +1260,6 @@ def test_read_concurrent_with_failing_partition_in_the_middle():
source = ConcurrentDeclarativeSource(
source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=[]
)
disable_emitting_sequential_state_messages(source=source)

location_slices = [
{"start": "2024-07-01", "end": "2024-07-31"},
Expand All @@ -1263,6 +1286,10 @@ def test_read_concurrent_with_failing_partition_in_the_middle():


@freezegun.freeze_time(_NOW)
@patch(
"airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",
mocked_init,
)
def test_read_concurrent_skip_streams_not_in_catalog():
"""
Verifies that the ConcurrentDeclarativeSource only syncs streams that are specified in the incoming ConfiguredCatalog
Expand Down Expand Up @@ -1311,8 +1338,6 @@ def test_read_concurrent_skip_streams_not_in_catalog():
# palaces requests
http_mocker.get(HttpRequest("https://persona.metaverse.com/palaces"), _PALACES_RESPONSE)

disable_emitting_sequential_state_messages(source=source)

messages = list(
source.read(logger=source.logger, config=_CONFIG, catalog=catalog, state=[])
)
Expand Down Expand Up @@ -1429,11 +1454,12 @@ def test_streams_with_stream_state_interpolation_should_be_synchronous():
catalog=_CATALOG,
state=None,
)
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

# 1 full refresh stream, 2 with parent stream without incremental dependency
assert len(source._concurrent_streams) == 3
assert len(concurrent_streams) == 3
# 2 incremental stream with interpolation on state (locations and party_members), 1 incremental with parent stream (palace_enemies), 1 stream with async retriever
assert len(source._synchronous_streams) == 4
assert len(synchronous_streams) == 4


def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurrent():
Expand Down Expand Up @@ -1569,9 +1595,10 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr
source = ConcurrentDeclarativeSource(
source_config=manifest, config=_CONFIG, catalog=catalog, state=state
)
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

assert len(source._concurrent_streams) == 0
assert len(source._synchronous_streams) == 1
assert len(concurrent_streams) == 0
assert len(synchronous_streams) == 1


def create_wrapped_stream(stream: DeclarativeStream) -> Stream:
Expand Down Expand Up @@ -1725,9 +1752,3 @@ def get_states_for_stream(
for message in messages
if message.state and message.state.stream.stream_descriptor.name == stream_name
]


def disable_emitting_sequential_state_messages(source: ConcurrentDeclarativeSource) -> None:
for concurrent_stream in source._concurrent_streams: # type: ignore # This is the easiest way to disable behavior from the test
if isinstance(concurrent_stream.cursor, ConcurrentCursor):
concurrent_stream.cursor._connector_state_converter._is_sequential_state = False # type: ignore # see above
Loading