Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: (low code) run state migrations for concurrent streams #316

16 changes: 15 additions & 1 deletion airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

import logging
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple
from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple

from airbyte_cdk.models import (
AirbyteCatalog,
Expand Down Expand Up @@ -224,6 +224,7 @@ def _group_streams(
stream_state = self._connector_state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
stream_state = self._migrate_state(declarative_stream, stream_state)
darynaishchenko marked this conversation as resolved.
Show resolved Hide resolved

retriever = self._get_retriever(declarative_stream, stream_state)

Expand Down Expand Up @@ -331,6 +332,8 @@ def _group_streams(
stream_state = self._connector_state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
stream_state = self._migrate_state(declarative_stream, stream_state)

partition_router = declarative_stream.retriever.stream_slicer._partition_router

perpartition_cursor = (
Expand Down Expand Up @@ -521,3 +524,14 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

@staticmethod
def _migrate_state(
    declarative_stream: DeclarativeStream, stream_state: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Apply the stream's state migrations, in order, to ``stream_state``.

    Each entry of ``declarative_stream.state_migrations`` is asked whether it
    should migrate the current state; when it answers yes, the state is replaced
    by a ``dict`` copy of what ``migrate`` returns. Later migrations therefore
    see the output of earlier ones.

    :param declarative_stream: stream whose ``state_migrations`` are iterated.
    :param stream_state: the state to migrate.
    :return: the original ``stream_state`` object when no migration fired,
        otherwise a new ``dict`` holding the migrated state.
    """
    current_state = stream_state
    for migration in declarative_stream.state_migrations:
        if not migration.should_migrate(current_state):
            continue
        # Downstream code expects a mutable state, while migrate() is typed as
        # returning a read-only mapping, hence the copy into a dict.
        current_state = dict(migration.migrate(current_state))
    return current_state
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,17 @@ def create_concurrency_level(
parameters={},
)

@staticmethod
def apply_stream_state_migrations(
    stream_state_migrations: List[Any] | None, stream_state: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Run every applicable migration over ``stream_state`` and return the result.

    :param stream_state_migrations: migrations to try, in order; ``None`` or an
        empty list means there is nothing to do.
    :param stream_state: the state to migrate.
    :return: ``stream_state`` itself when no migration was applied, otherwise a
        new ``dict`` built from the last applied migration's output.
    """
    if not stream_state_migrations:
        return stream_state

    migrated_state = stream_state
    for migration in stream_state_migrations:
        if migration.should_migrate(migrated_state):
            # migrate() is typed as returning a read-only mapping, but the state
            # is expected to be mutable afterwards, so copy it into a dict.
            migrated_state = dict(migration.migrate(migrated_state))
    return migrated_state

def create_concurrent_cursor_from_datetime_based_cursor(
self,
model_type: Type[BaseModel],
Expand All @@ -943,6 +954,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
config: Config,
message_repository: Optional[MessageRepository] = None,
runtime_lookback_window: Optional[datetime.timedelta] = None,
stream_state_migrations: Optional[List[Any]] = None,
**kwargs: Any,
) -> ConcurrentCursor:
# Per-partition incremental streams can dynamically create child cursors which will pass their current
Expand All @@ -953,6 +965,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
if "stream_state" not in kwargs
else kwargs["stream_state"]
)
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)

component_type = component_definition.get("type")
if component_definition.get("type") != model_type.__name__:
Expand Down Expand Up @@ -1188,6 +1201,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
config: Config,
stream_state: MutableMapping[str, Any],
partition_router: PartitionRouter,
stream_state_migrations: Optional[List[Any]] = None,
**kwargs: Any,
) -> ConcurrentPerPartitionCursor:
component_type = component_definition.get("type")
Expand Down Expand Up @@ -1236,8 +1250,10 @@ def create_concurrent_cursor_from_perpartition_cursor(
stream_namespace=stream_namespace,
config=config,
message_repository=NoopMessageRepository(),
stream_state_migrations=stream_state_migrations,
darynaishchenko marked this conversation as resolved.
Show resolved Hide resolved
)
)
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)

# Return the concurrent cursor and state converter
return ConcurrentPerPartitionCursor(
Expand Down Expand Up @@ -1746,6 +1762,7 @@ def _merge_stream_slicers(
stream_name=model.name or "",
stream_namespace=None,
config=config or {},
stream_state_migrations=model.state_migrations,
)
return (
self._create_component_from_model(model=model.incremental_sync, config=config)
Expand Down
47 changes: 47 additions & 0 deletions unit_tests/sources/declarative/custom_state_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

from typing import Any, Mapping

from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.types import Config


class CustomStateMigration(StateMigration):
    """Test-only migration that rewrites a single-cursor state into a per-partition state.

    The incoming state is expected to hold one value under the stream's cursor
    field. ``migrate`` fans that value out into a ``{"states": [...]}`` structure
    with one entry for each of the fixed partitions ``type_1`` and ``type_2``.
    """

    declarative_stream: DeclarativeStream
    config: Config

    def __init__(self, declarative_stream: DeclarativeStream, config: Config):
        """Capture the stream and config and resolve the cursor field name.

        :param declarative_stream: stream being migrated; its ``incremental_sync``
            and ``parameters`` attributes are read here.
        :param config: connector config used to evaluate the cursor field.
        """
        # Populate the declared public attribute as well as the private alias.
        self.config = config
        self._config = config
        self.declarative_stream = declarative_stream
        self._cursor = declarative_stream.incremental_sync
        self._parameters = declarative_stream.parameters
        # Resolve the cursor field once so any interpolation in the manifest is
        # applied before the name is used as a state key.
        self._cursor_field = InterpolatedString.create(
            self._cursor.cursor_field, parameters=self._parameters
        ).eval(self._config)

    def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
        """Always request a migration."""
        return True

    def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
        """Return a per-partition state built from the legacy cursor value.

        :param stream_state: legacy state containing the cursor field.
        :return: ``stream_state`` unchanged if ``should_migrate`` declines,
            otherwise a new mapping with a ``states`` list.
        :raises KeyError: if the cursor field is missing from ``stream_state``.
        """
        if not self.should_migrate(stream_state):
            return stream_state

        # Use the evaluated field name computed in __init__, not the raw
        # (possibly interpolated) value on the cursor component.
        cursor_field = self._cursor_field
        updated_at = stream_state[cursor_field]

        return {
            "states": [
                {
                    "partition": {"type": partition_type},
                    "cursor": {cursor_field: updated_at},
                }
                for partition_type in ("type_1", "type_2")
            ]
        }
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,126 @@ def test_create_concurrent_cursor_from_datetime_based_cursor(
assert getattr(concurrent_cursor, assertion_field) == expected_value


def test_create_concurrent_cursor_from_datetime_based_cursor_runs_state_migrations():
    """The datetime-based concurrent cursor factory applies ``stream_state_migrations``.

    A legacy ``{"updated_at": ...}`` state is passed together with a migration
    that turns it into a two-partition ``states`` list; the resulting cursor's
    state must expose that migrated list.
    """

    class DummyStateMigration:
        def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
            return True

        def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
            cursor_value = stream_state["updated_at"]
            return {
                "states": [
                    {
                        "partition": {"type": partition_type},
                        "cursor": {"updated_at": cursor_value},
                    }
                    for partition_type in ("type_1", "type_2")
                ]
            }

    legacy_state = {"updated_at": "2025-01-01T00:00:00.000000Z"}
    source_config = {
        "start_time": "2024-08-01T00:00:00.000000Z",
        "end_time": "2024-09-01T00:00:00.000000Z",
    }
    factory = ModelToComponentFactory(emit_connector_builder_messages=True)
    state_manager = ConnectorStateManager()
    cursor_definition = {
        "type": "DatetimeBasedCursor",
        "cursor_field": "updated_at",
        "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
        "start_datetime": "{{ config['start_time'] }}",
        "end_datetime": "{{ config['end_time'] }}",
        "partition_field_start": "custom_start",
        "partition_field_end": "custom_end",
        "step": "P10D",
        "cursor_granularity": "PT0.000001S",
        "lookback_window": "P3D",
    }

    concurrent_cursor = factory.create_concurrent_cursor_from_datetime_based_cursor(
        state_manager=state_manager,
        model_type=DatetimeBasedCursorModel,
        component_definition=cursor_definition,
        stream_name="test",
        stream_namespace=None,
        config=source_config,
        stream_state=legacy_state,
        stream_state_migrations=[DummyStateMigration()],
    )

    # The legacy cursor value is read after the call, as in the original test.
    expected_states = [
        {"cursor": {"updated_at": legacy_state["updated_at"]}, "partition": {"type": "type_1"}},
        {"cursor": {"updated_at": legacy_state["updated_at"]}, "partition": {"type": "type_2"}},
    ]
    assert concurrent_cursor.state["states"] == expected_states


def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations():
    """The per-partition concurrent cursor factory applies ``stream_state_migrations``.

    The migration overwrites ``lookback_window`` (10 in the fixture) with 20; the
    cursor built by the factory must report the migrated value.
    """

    class DummyStateMigration:
        def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
            return True

        def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
            # Updates the given state in place and returns that same object.
            stream_state["lookback_window"] = 20
            return stream_state

    initial_state = {
        "states": [
            {
                "partition": {"type": "typ_1"},
                "cursor": {"updated_at": "2024-08-01T00:00:00.000000Z"},
            }
        ],
        "state": {"updated_at": "2024-08-01T00:00:00.000000Z"},
        "lookback_window": 10,
        "parent_state": {"parent_test": {"last_updated": "2024-08-01T00:00:00.000000Z"}},
    }
    source_config = {
        "start_time": "2024-08-01T00:00:00.000000Z",
        "end_time": "2024-09-01T00:00:00.000000Z",
    }
    router = ListPartitionRouter(
        cursor_field="id",
        values=["type_1", "type_2", "type_3"],
        config=source_config,
        parameters={},
    )
    state_manager = ConnectorStateManager()
    cursor_definition = {
        "type": "DatetimeBasedCursor",
        "cursor_field": "updated_at",
        "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
        "start_datetime": "{{ config['start_time'] }}",
        "end_datetime": "{{ config['end_time'] }}",
        "partition_field_start": "custom_start",
        "partition_field_end": "custom_end",
        "step": "P10D",
        "cursor_granularity": "PT0.000001S",
        "lookback_window": "P3D",
    }
    factory = ModelToComponentFactory(emit_connector_builder_messages=True)

    cursor = factory.create_concurrent_cursor_from_perpartition_cursor(
        state_manager=state_manager,
        model_type=DatetimeBasedCursorModel,
        component_definition=cursor_definition,
        stream_name="test",
        stream_namespace=None,
        config=source_config,
        stream_state=initial_state,
        partition_router=router,
        stream_state_migrations=[DummyStateMigration()],
    )

    migrated_lookback = cursor.state["lookback_window"]
    assert migrated_lookback != 10, "State migration wasn't called"
    assert (
        migrated_lookback == 20
    ), "State migration was called, but actual state don't match expected"


def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined():
"""
Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to
Expand Down
Loading
Loading