Skip to content

Commit

Permalink
fix: (low code) run state migrations for concurrent streams (#316)
Browse files Browse the repository at this point in the history
Co-authored-by: octavia-squidington-iii <[email protected]>
  • Loading branch information
darynaishchenko and octavia-squidington-iii authored Feb 10, 2025
1 parent 6260248 commit 74631d8
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 1 deletion.
16 changes: 15 additions & 1 deletion airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

import logging
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple
from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple

from airbyte_cdk.models import (
AirbyteCatalog,
Expand Down Expand Up @@ -224,6 +224,7 @@ def _group_streams(
stream_state = self._connector_state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
stream_state = self._migrate_state(declarative_stream, stream_state)

retriever = self._get_retriever(declarative_stream, stream_state)

Expand Down Expand Up @@ -331,6 +332,8 @@ def _group_streams(
stream_state = self._connector_state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
stream_state = self._migrate_state(declarative_stream, stream_state)

partition_router = declarative_stream.retriever.stream_slicer._partition_router

perpartition_cursor = (
Expand Down Expand Up @@ -521,3 +524,14 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

@staticmethod
def _migrate_state(
    declarative_stream: DeclarativeStream, stream_state: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Apply the stream's configured state migrations, in order, to ``stream_state``.

    Each migration that reports it should run replaces the working state with its
    migrated result; migrations that do not apply leave the state untouched.
    """
    migrated_state = stream_state
    for migration in declarative_stream.state_migrations:
        if migration.should_migrate(migrated_state):
            # migrate() may hand back an immutable mapping; copy it into a plain
            # dict so downstream code can keep treating the state as mutable.
            migrated_state = dict(migration.migrate(migrated_state))

    return migrated_state
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,17 @@ def create_concurrency_level(
parameters={},
)

@staticmethod
def apply_stream_state_migrations(
    stream_state_migrations: Optional[List[Any]], stream_state: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Run each applicable state migration against ``stream_state`` and return the result.

    Args:
        stream_state_migrations: Ordered migrations to consider; ``None`` or an
            empty list leaves the state unchanged.
        stream_state: The current (possibly legacy-format) stream state.

    Returns:
        The stream state after all applicable migrations have been applied.
    """
    # NOTE: annotation normalized from `List[Any] | None` to `Optional[List[Any]]`
    # to match the Optional[...] style used by the surrounding factory methods.
    if stream_state_migrations:
        for state_migration in stream_state_migrations:
            if state_migration.should_migrate(stream_state):
                # The state variable is expected to be mutable but the migrate
                # method returns an immutable mapping, so copy it into a dict.
                stream_state = dict(state_migration.migrate(stream_state))
    return stream_state

def create_concurrent_cursor_from_datetime_based_cursor(
self,
model_type: Type[BaseModel],
Expand All @@ -943,6 +954,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
config: Config,
message_repository: Optional[MessageRepository] = None,
runtime_lookback_window: Optional[datetime.timedelta] = None,
stream_state_migrations: Optional[List[Any]] = None,
**kwargs: Any,
) -> ConcurrentCursor:
# Per-partition incremental streams can dynamically create child cursors which will pass their current
Expand All @@ -953,6 +965,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
if "stream_state" not in kwargs
else kwargs["stream_state"]
)
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)

component_type = component_definition.get("type")
if component_definition.get("type") != model_type.__name__:
Expand Down Expand Up @@ -1188,6 +1201,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
config: Config,
stream_state: MutableMapping[str, Any],
partition_router: PartitionRouter,
stream_state_migrations: Optional[List[Any]] = None,
**kwargs: Any,
) -> ConcurrentPerPartitionCursor:
component_type = component_definition.get("type")
Expand Down Expand Up @@ -1236,8 +1250,10 @@ def create_concurrent_cursor_from_perpartition_cursor(
stream_namespace=stream_namespace,
config=config,
message_repository=NoopMessageRepository(),
stream_state_migrations=stream_state_migrations,
)
)
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)

# Return the concurrent cursor and state converter
return ConcurrentPerPartitionCursor(
Expand Down Expand Up @@ -1746,6 +1762,7 @@ def _merge_stream_slicers(
stream_name=model.name or "",
stream_namespace=None,
config=config or {},
stream_state_migrations=model.state_migrations,
)
return (
self._create_component_from_model(model=model.incremental_sync, config=config)
Expand Down
47 changes: 47 additions & 0 deletions unit_tests/sources/declarative/custom_state_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

from typing import Any, Mapping

from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.types import Config


class CustomStateMigration(StateMigration):
    """Test helper migration that fans a single legacy cursor value out into one
    per-partition cursor entry for each known partition type."""

    declarative_stream: DeclarativeStream
    config: Config

    def __init__(self, declarative_stream: DeclarativeStream, config: Config):
        self._config = config
        self.declarative_stream = declarative_stream
        self._cursor = declarative_stream.incremental_sync
        self._parameters = declarative_stream.parameters
        self._cursor_field = InterpolatedString.create(
            self._cursor.cursor_field, parameters=self._parameters
        ).eval(self._config)

    def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
        # This helper always applies.
        return True

    def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
        if not self.should_migrate(stream_state):
            return stream_state

        cursor_value = stream_state[self._cursor.cursor_field]
        return {
            "states": [
                {
                    "partition": {"type": partition_type},
                    "cursor": {self._cursor.cursor_field: cursor_value},
                }
                for partition_type in ("type_1", "type_2")
            ]
        }
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,126 @@ def test_create_concurrent_cursor_from_datetime_based_cursor(
assert getattr(concurrent_cursor, assertion_field) == expected_value


def test_create_concurrent_cursor_from_datetime_based_cursor_runs_state_migrations():
    """Verify that state migrations passed to the datetime-based concurrent cursor
    factory are applied to the stream state before the cursor is built."""

    class DummyStateMigration:
        def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
            return True

        def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
            cursor_value = stream_state["updated_at"]
            return {
                "states": [
                    {"partition": {"type": "type_1"}, "cursor": {"updated_at": cursor_value}},
                    {"partition": {"type": "type_2"}, "cursor": {"updated_at": cursor_value}},
                ]
            }

    config = {
        "start_time": "2024-08-01T00:00:00.000000Z",
        "end_time": "2024-09-01T00:00:00.000000Z",
    }
    incoming_state = {"updated_at": "2025-01-01T00:00:00.000000Z"}
    factory = ModelToComponentFactory(emit_connector_builder_messages=True)
    state_manager = ConnectorStateManager()
    cursor_definition = {
        "type": "DatetimeBasedCursor",
        "cursor_field": "updated_at",
        "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
        "start_datetime": "{{ config['start_time'] }}",
        "end_datetime": "{{ config['end_time'] }}",
        "partition_field_start": "custom_start",
        "partition_field_end": "custom_end",
        "step": "P10D",
        "cursor_granularity": "PT0.000001S",
        "lookback_window": "P3D",
    }

    concurrent_cursor = factory.create_concurrent_cursor_from_datetime_based_cursor(
        state_manager=state_manager,
        model_type=DatetimeBasedCursorModel,
        component_definition=cursor_definition,
        stream_name="test",
        stream_namespace=None,
        config=config,
        stream_state=incoming_state,
        stream_state_migrations=[DummyStateMigration()],
    )

    expected_cursor = {"updated_at": incoming_state["updated_at"]}
    assert concurrent_cursor.state["states"] == [
        {"cursor": expected_cursor, "partition": {"type": "type_1"}},
        {"cursor": expected_cursor, "partition": {"type": "type_2"}},
    ]


def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations():
    """Verify that state migrations passed to the per-partition concurrent cursor
    factory are applied to the incoming state before the cursor is built."""

    class DummyStateMigration:
        def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
            return True

        def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
            # Double the lookback window so the test can detect that the
            # migration actually ran.
            stream_state["lookback_window"] = 10 * 2
            return stream_state

    incoming_state = {
        "states": [
            {
                "partition": {"type": "typ_1"},
                "cursor": {"updated_at": "2024-08-01T00:00:00.000000Z"},
            }
        ],
        "state": {"updated_at": "2024-08-01T00:00:00.000000Z"},
        "lookback_window": 10,
        "parent_state": {"parent_test": {"last_updated": "2024-08-01T00:00:00.000000Z"}},
    }
    config = {
        "start_time": "2024-08-01T00:00:00.000000Z",
        "end_time": "2024-09-01T00:00:00.000000Z",
    }
    router = ListPartitionRouter(
        cursor_field="id",
        values=["type_1", "type_2", "type_3"],
        config=config,
        parameters={},
    )
    cursor_definition = {
        "type": "DatetimeBasedCursor",
        "cursor_field": "updated_at",
        "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
        "start_datetime": "{{ config['start_time'] }}",
        "end_datetime": "{{ config['end_time'] }}",
        "partition_field_start": "custom_start",
        "partition_field_end": "custom_end",
        "step": "P10D",
        "cursor_granularity": "PT0.000001S",
        "lookback_window": "P3D",
    }
    factory = ModelToComponentFactory(emit_connector_builder_messages=True)

    cursor = factory.create_concurrent_cursor_from_perpartition_cursor(
        state_manager=ConnectorStateManager(),
        model_type=DatetimeBasedCursorModel,
        component_definition=cursor_definition,
        stream_name="test",
        stream_namespace=None,
        config=config,
        stream_state=incoming_state,
        partition_router=router,
        stream_state_migrations=[DummyStateMigration()],
    )

    assert cursor.state["lookback_window"] != 10, "State migration wasn't called"
    assert (
        cursor.state["lookback_window"] == 20
    ), "State migration was called, but actual state don't match expected"


def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined():
"""
Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to
Expand Down
Loading

0 comments on commit 74631d8

Please sign in to comment.