Skip to content

Commit 74631d8

Browse files
darynaishchenkooctavia-squidington-iii
and
octavia-squidington-iii
authored
fix: (low code)run state migrations for concurrent streams (#316)
Co-authored-by: octavia-squidington-iii <[email protected]>
1 parent 6260248 commit 74631d8

File tree

5 files changed

+350
-1
lines changed

5 files changed

+350
-1
lines changed

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#
44

55
import logging
6-
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple
6+
from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple
77

88
from airbyte_cdk.models import (
99
AirbyteCatalog,
@@ -224,6 +224,7 @@ def _group_streams(
224224
stream_state = self._connector_state_manager.get_stream_state(
225225
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
226226
)
227+
stream_state = self._migrate_state(declarative_stream, stream_state)
227228

228229
retriever = self._get_retriever(declarative_stream, stream_state)
229230

@@ -331,6 +332,8 @@ def _group_streams(
331332
stream_state = self._connector_state_manager.get_stream_state(
332333
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
333334
)
335+
stream_state = self._migrate_state(declarative_stream, stream_state)
336+
334337
partition_router = declarative_stream.retriever.stream_slicer._partition_router
335338

336339
perpartition_cursor = (
@@ -521,3 +524,14 @@ def _remove_concurrent_streams_from_catalog(
521524
if stream.stream.name not in concurrent_stream_names
522525
]
523526
)
527+
528+
@staticmethod
529+
def _migrate_state(
530+
declarative_stream: DeclarativeStream, stream_state: MutableMapping[str, Any]
531+
) -> MutableMapping[str, Any]:
532+
for state_migration in declarative_stream.state_migrations:
533+
if state_migration.should_migrate(stream_state):
534+
# The state variable is expected to be mutable but the migrate method returns an immutable mapping.
535+
stream_state = dict(state_migration.migrate(stream_state))
536+
537+
return stream_state

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

+17
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,17 @@ def create_concurrency_level(
934934
parameters={},
935935
)
936936

937+
@staticmethod
938+
def apply_stream_state_migrations(
939+
stream_state_migrations: List[Any] | None, stream_state: MutableMapping[str, Any]
940+
) -> MutableMapping[str, Any]:
941+
if stream_state_migrations:
942+
for state_migration in stream_state_migrations:
943+
if state_migration.should_migrate(stream_state):
944+
# The state variable is expected to be mutable but the migrate method returns an immutable mapping.
945+
stream_state = dict(state_migration.migrate(stream_state))
946+
return stream_state
947+
937948
def create_concurrent_cursor_from_datetime_based_cursor(
938949
self,
939950
model_type: Type[BaseModel],
@@ -943,6 +954,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
943954
config: Config,
944955
message_repository: Optional[MessageRepository] = None,
945956
runtime_lookback_window: Optional[datetime.timedelta] = None,
957+
stream_state_migrations: Optional[List[Any]] = None,
946958
**kwargs: Any,
947959
) -> ConcurrentCursor:
948960
# Per-partition incremental streams can dynamically create child cursors which will pass their current
@@ -953,6 +965,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
953965
if "stream_state" not in kwargs
954966
else kwargs["stream_state"]
955967
)
968+
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
956969

957970
component_type = component_definition.get("type")
958971
if component_definition.get("type") != model_type.__name__:
@@ -1188,6 +1201,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
11881201
config: Config,
11891202
stream_state: MutableMapping[str, Any],
11901203
partition_router: PartitionRouter,
1204+
stream_state_migrations: Optional[List[Any]] = None,
11911205
**kwargs: Any,
11921206
) -> ConcurrentPerPartitionCursor:
11931207
component_type = component_definition.get("type")
@@ -1236,8 +1250,10 @@ def create_concurrent_cursor_from_perpartition_cursor(
12361250
stream_namespace=stream_namespace,
12371251
config=config,
12381252
message_repository=NoopMessageRepository(),
1253+
stream_state_migrations=stream_state_migrations,
12391254
)
12401255
)
1256+
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
12411257

12421258
# Return the concurrent cursor and state converter
12431259
return ConcurrentPerPartitionCursor(
@@ -1746,6 +1762,7 @@ def _merge_stream_slicers(
17461762
stream_name=model.name or "",
17471763
stream_namespace=None,
17481764
config=config or {},
1765+
stream_state_migrations=model.state_migrations,
17491766
)
17501767
return (
17511768
self._create_component_from_model(model=model.incremental_sync, config=config)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#
2+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
from typing import Any, Mapping
6+
7+
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
8+
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
9+
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
10+
from airbyte_cdk.sources.types import Config
11+
12+
13+
class CustomStateMigration(StateMigration):
14+
declarative_stream: DeclarativeStream
15+
config: Config
16+
17+
def __init__(self, declarative_stream: DeclarativeStream, config: Config):
18+
self._config = config
19+
self.declarative_stream = declarative_stream
20+
self._cursor = declarative_stream.incremental_sync
21+
self._parameters = declarative_stream.parameters
22+
self._cursor_field = InterpolatedString.create(
23+
self._cursor.cursor_field, parameters=self._parameters
24+
).eval(self._config)
25+
26+
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
27+
return True
28+
29+
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
30+
if not self.should_migrate(stream_state):
31+
return stream_state
32+
updated_at = stream_state[self._cursor.cursor_field]
33+
34+
migrated_stream_state = {
35+
"states": [
36+
{
37+
"partition": {"type": "type_1"},
38+
"cursor": {self._cursor.cursor_field: updated_at},
39+
},
40+
{
41+
"partition": {"type": "type_2"},
42+
"cursor": {self._cursor.cursor_field: updated_at},
43+
},
44+
]
45+
}
46+
47+
return migrated_stream_state

unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

+120
Original file line numberDiff line numberDiff line change
@@ -3281,6 +3281,126 @@ def test_create_concurrent_cursor_from_datetime_based_cursor(
32813281
assert getattr(concurrent_cursor, assertion_field) == expected_value
32823282

32833283

3284+
def test_create_concurrent_cursor_from_datetime_based_cursor_runs_state_migrations():
3285+
class DummyStateMigration:
3286+
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
3287+
return True
3288+
3289+
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
3290+
updated_at = stream_state["updated_at"]
3291+
return {
3292+
"states": [
3293+
{
3294+
"partition": {"type": "type_1"},
3295+
"cursor": {"updated_at": updated_at},
3296+
},
3297+
{
3298+
"partition": {"type": "type_2"},
3299+
"cursor": {"updated_at": updated_at},
3300+
},
3301+
]
3302+
}
3303+
3304+
stream_name = "test"
3305+
config = {
3306+
"start_time": "2024-08-01T00:00:00.000000Z",
3307+
"end_time": "2024-09-01T00:00:00.000000Z",
3308+
}
3309+
stream_state = {"updated_at": "2025-01-01T00:00:00.000000Z"}
3310+
connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True)
3311+
connector_state_manager = ConnectorStateManager()
3312+
cursor_component_definition = {
3313+
"type": "DatetimeBasedCursor",
3314+
"cursor_field": "updated_at",
3315+
"datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
3316+
"start_datetime": "{{ config['start_time'] }}",
3317+
"end_datetime": "{{ config['end_time'] }}",
3318+
"partition_field_start": "custom_start",
3319+
"partition_field_end": "custom_end",
3320+
"step": "P10D",
3321+
"cursor_granularity": "PT0.000001S",
3322+
"lookback_window": "P3D",
3323+
}
3324+
concurrent_cursor = (
3325+
connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor(
3326+
state_manager=connector_state_manager,
3327+
model_type=DatetimeBasedCursorModel,
3328+
component_definition=cursor_component_definition,
3329+
stream_name=stream_name,
3330+
stream_namespace=None,
3331+
config=config,
3332+
stream_state=stream_state,
3333+
stream_state_migrations=[DummyStateMigration()],
3334+
)
3335+
)
3336+
assert concurrent_cursor.state["states"] == [
3337+
{"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_1"}},
3338+
{"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_2"}},
3339+
]
3340+
3341+
3342+
def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations():
3343+
class DummyStateMigration:
3344+
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
3345+
return True
3346+
3347+
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
3348+
stream_state["lookback_window"] = 10 * 2
3349+
return stream_state
3350+
3351+
state = {
3352+
"states": [
3353+
{
3354+
"partition": {"type": "typ_1"},
3355+
"cursor": {"updated_at": "2024-08-01T00:00:00.000000Z"},
3356+
}
3357+
],
3358+
"state": {"updated_at": "2024-08-01T00:00:00.000000Z"},
3359+
"lookback_window": 10,
3360+
"parent_state": {"parent_test": {"last_updated": "2024-08-01T00:00:00.000000Z"}},
3361+
}
3362+
config = {
3363+
"start_time": "2024-08-01T00:00:00.000000Z",
3364+
"end_time": "2024-09-01T00:00:00.000000Z",
3365+
}
3366+
list_partition_router = ListPartitionRouter(
3367+
cursor_field="id",
3368+
values=["type_1", "type_2", "type_3"],
3369+
config=config,
3370+
parameters={},
3371+
)
3372+
connector_state_manager = ConnectorStateManager()
3373+
stream_name = "test"
3374+
cursor_component_definition = {
3375+
"type": "DatetimeBasedCursor",
3376+
"cursor_field": "updated_at",
3377+
"datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
3378+
"start_datetime": "{{ config['start_time'] }}",
3379+
"end_datetime": "{{ config['end_time'] }}",
3380+
"partition_field_start": "custom_start",
3381+
"partition_field_end": "custom_end",
3382+
"step": "P10D",
3383+
"cursor_granularity": "PT0.000001S",
3384+
"lookback_window": "P3D",
3385+
}
3386+
connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True)
3387+
cursor = connector_builder_factory.create_concurrent_cursor_from_perpartition_cursor(
3388+
state_manager=connector_state_manager,
3389+
model_type=DatetimeBasedCursorModel,
3390+
component_definition=cursor_component_definition,
3391+
stream_name=stream_name,
3392+
stream_namespace=None,
3393+
config=config,
3394+
stream_state=state,
3395+
partition_router=list_partition_router,
3396+
stream_state_migrations=[DummyStateMigration()],
3397+
)
3398+
assert cursor.state["lookback_window"] != 10, "State migration wasn't called"
3399+
assert (
3400+
cursor.state["lookback_window"] == 20
3401+
), "State migration was called, but actual state don't match expected"
3402+
3403+
32843404
def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined():
32853405
"""
32863406
Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to

0 commit comments

Comments
 (0)