diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index d4ecc0084..96bd67c32 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,7 +3,7 @@ # import logging -from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple +from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple from airbyte_cdk.models import ( AirbyteCatalog, @@ -224,6 +224,7 @@ def _group_streams( stream_state = self._connector_state_manager.get_stream_state( stream_name=declarative_stream.name, namespace=declarative_stream.namespace ) + stream_state = self._migrate_state(declarative_stream, stream_state) retriever = self._get_retriever(declarative_stream, stream_state) @@ -331,6 +332,8 @@ def _group_streams( stream_state = self._connector_state_manager.get_stream_state( stream_name=declarative_stream.name, namespace=declarative_stream.namespace ) + stream_state = self._migrate_state(declarative_stream, stream_state) + partition_router = declarative_stream.retriever.stream_slicer._partition_router perpartition_cursor = ( @@ -521,3 +524,14 @@ def _remove_concurrent_streams_from_catalog( if stream.stream.name not in concurrent_stream_names ] ) + + @staticmethod + def _migrate_state( + declarative_stream: DeclarativeStream, stream_state: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: + for state_migration in declarative_stream.state_migrations: + if state_migration.should_migrate(stream_state): + # The state variable is expected to be mutable but the migrate method returns an immutable mapping. + stream_state = dict(state_migration.migrate(stream_state)) + + return stream_state diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index a664b8530..c6d69623d 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -934,6 +934,17 @@ def create_concurrency_level( parameters={}, ) + @staticmethod + def apply_stream_state_migrations( + stream_state_migrations: List[Any] | None, stream_state: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: + if stream_state_migrations: + for state_migration in stream_state_migrations: + if state_migration.should_migrate(stream_state): + # The state variable is expected to be mutable but the migrate method returns an immutable mapping. + stream_state = dict(state_migration.migrate(stream_state)) + return stream_state + def create_concurrent_cursor_from_datetime_based_cursor( self, model_type: Type[BaseModel], @@ -943,6 +954,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( config: Config, message_repository: Optional[MessageRepository] = None, runtime_lookback_window: Optional[datetime.timedelta] = None, + stream_state_migrations: Optional[List[Any]] = None, **kwargs: Any, ) -> ConcurrentCursor: # Per-partition incremental streams can dynamically create child cursors which will pass their current @@ -953,6 +965,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( if "stream_state" not in kwargs else kwargs["stream_state"] ) + stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: @@ -1188,6 +1201,7 @@ def create_concurrent_cursor_from_perpartition_cursor( config: Config, stream_state: MutableMapping[str, Any], partition_router: PartitionRouter, + stream_state_migrations: Optional[List[Any]] = None, **kwargs: Any, ) -> ConcurrentPerPartitionCursor: component_type = component_definition.get("type") @@ -1236,8 +1250,10 @@ def create_concurrent_cursor_from_perpartition_cursor( stream_namespace=stream_namespace, config=config, message_repository=NoopMessageRepository(), + stream_state_migrations=stream_state_migrations, ) ) + stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) # Return the concurrent cursor and state converter return ConcurrentPerPartitionCursor( @@ -1746,6 +1762,7 @@ def _merge_stream_slicers( stream_name=model.name or "", stream_namespace=None, config=config or {}, + stream_state_migrations=model.state_migrations, ) return ( self._create_component_from_model(model=model.incremental_sync, config=config) diff --git a/unit_tests/sources/declarative/custom_state_migration.py b/unit_tests/sources/declarative/custom_state_migration.py new file mode 100644 index 000000000..86ca4a5c4 --- /dev/null +++ b/unit_tests/sources/declarative/custom_state_migration.py @@ -0,0 +1,47 @@ +# +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +# + +from typing import Any, Mapping + +from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream +from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString +from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration +from airbyte_cdk.sources.types import Config + + +class CustomStateMigration(StateMigration): + declarative_stream: DeclarativeStream + config: Config + + def __init__(self, declarative_stream: DeclarativeStream, config: Config): + self._config = config + self.declarative_stream = declarative_stream + self._cursor = declarative_stream.incremental_sync + self._parameters = declarative_stream.parameters + self._cursor_field = InterpolatedString.create( + self._cursor.cursor_field, parameters=self._parameters + ).eval(self._config) + + def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: + return True + + def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: + if not self.should_migrate(stream_state): + return stream_state + updated_at = stream_state[self._cursor.cursor_field] + + migrated_stream_state = { + "states": [ + { + "partition": {"type": "type_1"}, + "cursor": {self._cursor.cursor_field: updated_at}, + }, + { + "partition": {"type": "type_2"}, + "cursor": {self._cursor.cursor_field: updated_at}, + }, + ] + } + + return migrated_stream_state diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index 43564a5c8..32a73f364 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -3281,6 +3281,126 @@ def test_create_concurrent_cursor_from_datetime_based_cursor( assert getattr(concurrent_cursor, assertion_field) == expected_value +def test_create_concurrent_cursor_from_datetime_based_cursor_runs_state_migrations(): + class DummyStateMigration: + def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: + return True + + def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: + updated_at = stream_state["updated_at"] + return { + "states": [ + { + "partition": {"type": "type_1"}, + "cursor": {"updated_at": updated_at}, + }, + { + "partition": {"type": "type_2"}, + "cursor": {"updated_at": updated_at}, + }, + ] + } + + stream_name = "test" + config = { + "start_time": "2024-08-01T00:00:00.000000Z", + "end_time": "2024-09-01T00:00:00.000000Z", + } + stream_state = {"updated_at": "2025-01-01T00:00:00.000000Z"} + connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) + connector_state_manager = ConnectorStateManager() + cursor_component_definition = { + "type": "DatetimeBasedCursor", + "cursor_field": "updated_at", + "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", + "start_datetime": "{{ config['start_time'] }}", + "end_datetime": "{{ config['end_time'] }}", + "partition_field_start": "custom_start", + "partition_field_end": "custom_end", + "step": "P10D", + "cursor_granularity": "PT0.000001S", + "lookback_window": "P3D", + } + concurrent_cursor = ( + connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( + state_manager=connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=cursor_component_definition, + stream_name=stream_name, + stream_namespace=None, + config=config, + stream_state=stream_state, + stream_state_migrations=[DummyStateMigration()], + ) + ) + assert concurrent_cursor.state["states"] == [ + {"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_1"}}, + {"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_2"}}, + ] + + +def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations(): + class DummyStateMigration: + def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: + return True + + def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: + stream_state["lookback_window"] = 10 * 2 + return stream_state + + state = { + "states": [ + { + "partition": {"type": "typ_1"}, + "cursor": {"updated_at": "2024-08-01T00:00:00.000000Z"}, + } + ], + "state": {"updated_at": "2024-08-01T00:00:00.000000Z"}, + "lookback_window": 10, + "parent_state": {"parent_test": {"last_updated": "2024-08-01T00:00:00.000000Z"}}, + } + config = { + "start_time": "2024-08-01T00:00:00.000000Z", + "end_time": "2024-09-01T00:00:00.000000Z", + } + list_partition_router = ListPartitionRouter( + cursor_field="id", + values=["type_1", "type_2", "type_3"], + config=config, + parameters={}, + ) + connector_state_manager = ConnectorStateManager() + stream_name = "test" + cursor_component_definition = { + "type": "DatetimeBasedCursor", + "cursor_field": "updated_at", + "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", + "start_datetime": "{{ config['start_time'] }}", + "end_datetime": "{{ config['end_time'] }}", + "partition_field_start": "custom_start", + "partition_field_end": "custom_end", + "step": "P10D", + "cursor_granularity": "PT0.000001S", + "lookback_window": "P3D", + } + connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) + cursor = connector_builder_factory.create_concurrent_cursor_from_perpartition_cursor( + state_manager=connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=cursor_component_definition, + stream_name=stream_name, + stream_namespace=None, + config=config, + stream_state=state, + partition_router=list_partition_router, + stream_state_migrations=[DummyStateMigration()], + ) + assert cursor.state["lookback_window"] != 10, "State migration wasn't called" + assert ( + cursor.state["lookback_window"] == 20 + ), "State migration was called, but actual state don't match expected" + + def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): """ Validates a special case for when the start_time.datetime_format and end_time.datetime_format are defined, the date to diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py index 1877e11bb..71874248d 100644 --- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py +++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py @@ -1231,6 +1231,157 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): assert len(party_members_skills_records) == 9 +def test_concurrent_declarative_source_runs_state_migrations_provided_in_manifest(): + manifest = { + "version": "5.0.0", + "definitions": { + "selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, + "requester": { + "type": "HttpRequester", + "url_base": "https://persona.metaverse.com", + "http_method": "GET", + "authenticator": { + "type": "BasicHttpAuthenticator", + "username": "{{ config['api_key'] }}", + "password": "{{ config['secret_key'] }}", + }, + "error_handler": { + "type": "DefaultErrorHandler", + "response_filters": [ + { + "http_codes": [403], + "action": "FAIL", + "failure_type": "config_error", + "error_message": "Access denied due to lack of permission or invalid API/Secret key or wrong data region.", + }, + { + "http_codes": [404], + "action": "IGNORE", + "error_message": "No data available for the time range requested.", + }, + ], + }, + }, + "retriever": { + "type": "SimpleRetriever", + "record_selector": {"$ref": "#/definitions/selector"}, + "paginator": {"type": "NoPagination"}, + "requester": {"$ref": "#/definitions/requester"}, + }, + "incremental_cursor": { + "type": "DatetimeBasedCursor", + "start_datetime": { + "datetime": "{{ format_datetime(config['start_date'], '%Y-%m-%d') }}" + }, + "end_datetime": {"datetime": "{{ now_utc().strftime('%Y-%m-%d') }}"}, + "datetime_format": "%Y-%m-%d", + "cursor_datetime_formats": ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"], + "cursor_granularity": "P1D", + "step": "P15D", + "cursor_field": "updated_at", + "lookback_window": "P5D", + "start_time_option": { + "type": "RequestOption", + "field_name": "start", + "inject_into": "request_parameter", + }, + "end_time_option": { + "type": "RequestOption", + "field_name": "end", + "inject_into": "request_parameter", + }, + }, + "base_stream": {"retriever": {"$ref": "#/definitions/retriever"}}, + "base_incremental_stream": { + "retriever": { + "$ref": "#/definitions/retriever", + "requester": {"$ref": "#/definitions/requester"}, + }, + "incremental_sync": {"$ref": "#/definitions/incremental_cursor"}, + }, + "party_members_stream": { + "$ref": "#/definitions/base_incremental_stream", + "retriever": { + "$ref": "#/definitions/base_incremental_stream/retriever", + "requester": { + "$ref": "#/definitions/requester", + "request_parameters": {"filter": "{{stream_partition['type']}}"}, + }, + "record_selector": {"$ref": "#/definitions/selector"}, + "partition_router": [ + { + "type": "ListPartitionRouter", + "values": ["type_1", "type_2"], + "cursor_field": "type", + } + ], + }, + "$parameters": { + "name": "party_members", + "primary_key": "id", + "path": "/party_members", + }, + "state_migrations": [ + { + "type": "CustomStateMigration", + "class_name": "unit_tests.sources.declarative.custom_state_migration.CustomStateMigration", + } + ], + "schema_loader": { + "type": "InlineSchemaLoader", + "schema": { + "$schema": "https://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "id": { + "description": "The identifier", + "type": ["null", "string"], + }, + "name": { + "description": "The name of the party member", + "type": ["null", "string"], + }, + }, + }, + }, + }, + }, + "streams": [ + "#/definitions/party_members_stream", + ], + "check": {"stream_names": ["party_members", "locations"]}, + "concurrency_level": { + "type": "ConcurrencyLevel", + "default_concurrency": "{{ config['num_workers'] or 10 }}", + "max_concurrency": 25, + }, + } + state_blob = AirbyteStateBlob(updated_at="2024-08-21") + state = [ + AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="party_members", namespace=None), + stream_state=state_blob, + ), + ), + ] + source = ConcurrentDeclarativeSource( + source_config=manifest, config=_CONFIG, catalog=_CATALOG, state=state + ) + concurrent_streams, synchronous_streams = source._group_streams(_CONFIG) + assert ( + concurrent_streams[0].cursor.state.get("state") != state_blob.__dict__ + ), "State was not migrated." + assert concurrent_streams[0].cursor.state.get("states") == [ + {"cursor": {"updated_at": "2024-08-21"}, "partition": {"type": "type_1"}}, + {"cursor": {"updated_at": "2024-08-21"}, "partition": {"type": "type_2"}}, + ], "State was migrated, but actual state don't match expected" + + @freezegun.freeze_time(_NOW) @patch( "airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter.AbstractStreamStateConverter.__init__",