From a6d55bee9fc0e7277cf695f387cf1edbe532000d Mon Sep 17 00:00:00 2001 From: "Aaron (\"AJ\") Steers" Date: Thu, 30 Jan 2025 09:58:29 -0800 Subject: [PATCH 01/12] fix: make new datetime parse functions more permissive (#296) Co-authored-by: octavia-squidington-iii --- airbyte_cdk/utils/datetime_helpers.py | 114 +++++++++------------ unit_tests/utils/test_datetime_helpers.py | 119 +++++++++++++--------- 2 files changed, 117 insertions(+), 116 deletions(-) diff --git a/airbyte_cdk/utils/datetime_helpers.py b/airbyte_cdk/utils/datetime_helpers.py index 008a18d86..99cf1ad23 100644 --- a/airbyte_cdk/utils/datetime_helpers.py +++ b/airbyte_cdk/utils/datetime_helpers.py @@ -76,8 +76,8 @@ assert ab_datetime_try_parse("2023-03-14T15:09:26Z") # Basic UTC format assert ab_datetime_try_parse("2023-03-14T15:09:26-04:00") # With timezone offset assert ab_datetime_try_parse("2023-03-14T15:09:26+00:00") # With explicit UTC offset -assert not ab_datetime_try_parse("2023-03-14 15:09:26Z") # Invalid: missing T delimiter -assert not ab_datetime_try_parse("foo") # Invalid: not a datetime +assert ab_datetime_try_parse("2023-03-14 15:09:26Z") # Missing T delimiter but still parsable +assert not ab_datetime_try_parse("foo") # Invalid: not parsable, returns `None` ``` """ @@ -138,6 +138,14 @@ def from_datetime(cls, dt: datetime) -> "AirbyteDateTime": dt.tzinfo or timezone.utc, ) + def to_datetime(self) -> datetime: + """Converts this AirbyteDateTime to a standard datetime object. + + Today, this just returns `self` because AirbyteDateTime is a subclass of `datetime`. + In the future, we may modify our internal representation to use a different base class. + """ + return self + def __str__(self) -> str: """Returns the datetime in ISO8601/RFC3339 format with 'T' delimiter. @@ -148,12 +156,7 @@ def __str__(self) -> str: str: ISO8601/RFC3339 formatted string. """ aware_self = self if self.tzinfo else self.replace(tzinfo=timezone.utc) - base = self.strftime("%Y-%m-%dT%H:%M:%S") - if self.microsecond: - base = f"{base}.{self.microsecond:06d}" - # Format timezone as ±HH:MM - offset = aware_self.strftime("%z") - return f"{base}{offset[:3]}:{offset[3:]}" + return aware_self.isoformat(sep="T", timespec="auto") def __repr__(self) -> str: """Returns the same string representation as __str__ for consistency. @@ -358,15 +361,15 @@ def ab_datetime_now() -> AirbyteDateTime: def ab_datetime_parse(dt_str: str | int) -> AirbyteDateTime: """Parses a datetime string or timestamp into an AirbyteDateTime with timezone awareness. - Previously named: parse() + This implementation is as flexible as possible to handle various datetime formats. + Always returns a timezone-aware datetime (defaults to UTC if no timezone specified). Handles: - - ISO8601/RFC3339 format strings (with 'T' delimiter) + - ISO8601/RFC3339 format strings (with ' ' or 'T' delimiter) - Unix timestamps (as integers or strings) - Date-only strings (YYYY-MM-DD) - Timezone-aware formats (+00:00 for UTC, or ±HH:MM offset) - - Always returns a timezone-aware datetime (defaults to UTC if no timezone specified). 
+ - Anything that can be parsed by `dateutil.parser.parse()` Args: dt_str: A datetime string in ISO8601/RFC3339 format, Unix timestamp (int/str), @@ -416,15 +419,16 @@ def ab_datetime_parse(dt_str: str | int) -> AirbyteDateTime: except (ValueError, TypeError): raise ValueError(f"Invalid date format: {dt_str}") - # Validate datetime format - if "/" in dt_str or " " in dt_str or "GMT" in dt_str: - raise ValueError(f"Could not parse datetime string: {dt_str}") + # Reject time-only strings without date + if ":" in dt_str and dt_str.count("-") < 2 and dt_str.count("/") < 2: + raise ValueError(f"Missing date part in datetime string: {dt_str}") # Try parsing with dateutil for timezone handling try: parsed = parser.parse(dt_str) if parsed.tzinfo is None: parsed = parsed.replace(tzinfo=timezone.utc) + return AirbyteDateTime.from_datetime(parsed) except (ValueError, TypeError): raise ValueError(f"Could not parse datetime string: {dt_str}") @@ -438,7 +442,29 @@ def ab_datetime_parse(dt_str: str | int) -> AirbyteDateTime: raise ValueError(f"Could not parse datetime string: {dt_str}") -def ab_datetime_format(dt: Union[datetime, AirbyteDateTime]) -> str: +def ab_datetime_try_parse(dt_str: str) -> AirbyteDateTime | None: + """Try to parse the input as a datetime, failing gracefully instead of raising an exception. + + This is a thin wrapper around `ab_datetime_parse()` that catches parsing errors and + returns `None` instead of raising an exception. + The implementation is as flexible as possible to handle various datetime formats. + Always returns a timezone-aware datetime (defaults to `UTC` if no timezone specified). + + Example: + >>> ab_datetime_try_parse("2023-03-14T15:09:26Z") # Returns AirbyteDateTime + >>> ab_datetime_try_parse("2023-03-14 15:09:26Z") # Missing 'T' delimiter still parsable + >>> ab_datetime_try_parse("2023-03-14") # Returns midnight UTC time + """ + try: + return ab_datetime_parse(dt_str) + except (ValueError, TypeError): + return None + + +def ab_datetime_format( + dt: Union[datetime, AirbyteDateTime], + format: str | None = None, +) -> str: """Formats a datetime object as an ISO8601/RFC3339 string with 'T' delimiter and timezone. Previously named: format() @@ -449,6 +475,8 @@ def ab_datetime_format(dt: Union[datetime, AirbyteDateTime]) -> str: Args: dt: Any datetime object to format. + format: Optional format string. If provided, calls `strftime()` with this format. + Otherwise, uses the default ISO8601/RFC3339 format, adapted for available precision. Returns: str: ISO8601/RFC3339 formatted datetime string. @@ -464,54 +492,8 @@ def ab_datetime_format(dt: Union[datetime, AirbyteDateTime]) -> str: if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) - # Format with consistent timezone representation - base = dt.strftime("%Y-%m-%dT%H:%M:%S") - if dt.microsecond: - base = f"{base}.{dt.microsecond:06d}" - offset = dt.strftime("%z") - return f"{base}{offset[:3]}:{offset[3:]}" - - -def ab_datetime_try_parse(dt_str: str) -> AirbyteDateTime | None: - """Try to parse the input string as an ISO8601/RFC3339 datetime, failing gracefully instead of raising an exception. 
- - Requires strict ISO8601/RFC3339 format with: - - 'T' delimiter between date and time components - - Valid timezone (Z for UTC or ±HH:MM offset) - - Complete datetime representation (date and time) + if format: + return dt.strftime(format) - Returns None for any non-compliant formats including: - - Space-delimited datetimes - - Date-only strings - - Missing timezone - - Invalid timezone format - - Wrong date/time separators - - Example: - >>> ab_datetime_try_parse("2023-03-14T15:09:26Z") # Returns AirbyteDateTime - >>> ab_datetime_try_parse("2023-03-14 15:09:26Z") # Returns None (invalid format) - >>> ab_datetime_try_parse("2023-03-14") # Returns None (missing time and timezone) - """ - if not isinstance(dt_str, str): - return None - try: - # Validate format before parsing - if "T" not in dt_str: - return None - if not any(x in dt_str for x in ["Z", "+", "-"]): - return None - if "/" in dt_str or " " in dt_str or "GMT" in dt_str: - return None - - # Try parsing with dateutil - parsed = parser.parse(dt_str) - if parsed.tzinfo is None: - return None - - # Validate time components - if not (0 <= parsed.hour <= 23 and 0 <= parsed.minute <= 59 and 0 <= parsed.second <= 59): - return None - - return AirbyteDateTime.from_datetime(parsed) - except (ValueError, TypeError): - return None + # Format with consistent timezone representation and "T" delimiter + return dt.isoformat(sep="T", timespec="auto") diff --git a/unit_tests/utils/test_datetime_helpers.py b/unit_tests/utils/test_datetime_helpers.py index cdc95cf78..88c61ef95 100644 --- a/unit_tests/utils/test_datetime_helpers.py +++ b/unit_tests/utils/test_datetime_helpers.py @@ -51,6 +51,63 @@ def test_now(): @pytest.mark.parametrize( "input_value,expected_output,error_type,error_match", [ + # Valid formats - must have T delimiter and timezone + ("2023-03-14T15:09:26+00:00", "2023-03-14T15:09:26+00:00", None, None), # Basic UTC format + ( + "2023-03-14T15:09:26.123+00:00", + "2023-03-14T15:09:26.123000+00:00", + None, + None, + ), # With milliseconds + ( + "2023-03-14T15:09:26.123456+00:00", + "2023-03-14T15:09:26.123456+00:00", + None, + None, + ), # With microseconds + ( + "2023-03-14T15:09:26-04:00", + "2023-03-14T15:09:26-04:00", + None, + None, + ), # With timezone offset + ("2023-03-14T15:09:26Z", "2023-03-14T15:09:26+00:00", None, None), # With Z timezone + ( + "2023-03-14T00:00:00+00:00", + "2023-03-14T00:00:00+00:00", + None, + None, + ), # Full datetime with zero time + ( + "2023-03-14T15:09:26GMT", + "2023-03-14T15:09:26+00:00", + None, + None, + ), # Non-standard timezone name ok + ( + "2023-03-14T15:09:26", + "2023-03-14T15:09:26+00:00", + None, + None, + ), # Missing timezone, assume UTC + ( + "2023-03-14 15:09:26", + "2023-03-14T15:09:26+00:00", + None, + None, + ), # Missing T delimiter ok, assume UTC + ( + "2023-03-14", + "2023-03-14T00:00:00+00:00", + None, + None, + ), # Date only, missing time and timezone + ( + "2023/03/14T15:09:26Z", + "2023-03-14T15:09:26+00:00", + None, + None, + ), # Wrong date separator, ok # Valid formats ("2023-03-14T15:09:26Z", "2023-03-14T15:09:26+00:00", None, None), ("2023-03-14T15:09:26-04:00", "2023-03-14T15:09:26-04:00", None, None), @@ -71,20 +128,10 @@ def test_now(): ("2023-12-32", None, ValueError, "Invalid date format: 2023-12-32"), ("2023-00-14", None, ValueError, "Invalid date format: 2023-00-14"), ("2023-12-00", None, ValueError, "Invalid date format: 2023-12-00"), - # Invalid separators and formats - ("2023/12/14", None, ValueError, "Could not parse datetime string: 2023/12/14"), - ( 
- "2023-03-14 15:09:26Z", - None, - ValueError, - "Could not parse datetime string: 2023-03-14 15:09:26Z", - ), - ( - "2023-03-14T15:09:26GMT", - None, - ValueError, - "Could not parse datetime string: 2023-03-14T15:09:26GMT", - ), + # Non-standard separators and formats, ok + ("2023/12/14", "2023-12-14T00:00:00+00:00", None, None), + ("2023-03-14 15:09:26Z", "2023-03-14T15:09:26+00:00", None, None), + ("2023-03-14T15:09:26GMT", "2023-03-14T15:09:26+00:00", None, None), # Invalid time components ( "2023-03-14T25:09:26Z", @@ -105,16 +152,24 @@ def test_now(): "Could not parse datetime string: 2023-03-14T15:09:99Z", ), ], + # ("invalid datetime", None), # Completely invalid + # ("15:09:26Z", None), # Missing date component + # ("2023-03-14T25:09:26Z", None), # Invalid hour + # ("2023-03-14T15:99:26Z", None), # Invalid minute + # ("2023-03-14T15:09:99Z", None), # Invalid second + # ("2023-02-30T00:00:00Z", None), # Impossible date ) def test_parse(input_value, expected_output, error_type, error_match): """Test parsing various datetime string formats.""" if error_type: with pytest.raises(error_type, match=error_match): ab_datetime_parse(input_value) + assert not ab_datetime_try_parse(input_value) else: dt = ab_datetime_parse(input_value) assert isinstance(dt, AirbyteDateTime) assert str(dt) == expected_output + assert ab_datetime_try_parse(input_value) and ab_datetime_try_parse(input_value) == dt @pytest.mark.parametrize( @@ -194,42 +249,6 @@ def test_operator_overloading(): _ = "invalid" - dt -@pytest.mark.parametrize( - "input_value,expected_output", - [ - # Valid formats - must have T delimiter and timezone - ("2023-03-14T15:09:26+00:00", "2023-03-14T15:09:26+00:00"), # Basic UTC format - ("2023-03-14T15:09:26.123+00:00", "2023-03-14T15:09:26.123000+00:00"), # With milliseconds - ( - "2023-03-14T15:09:26.123456+00:00", - "2023-03-14T15:09:26.123456+00:00", - ), # With microseconds - ("2023-03-14T15:09:26-04:00", "2023-03-14T15:09:26-04:00"), # With timezone offset - ("2023-03-14T15:09:26Z", "2023-03-14T15:09:26+00:00"), # With Z timezone - ("2023-03-14T00:00:00+00:00", "2023-03-14T00:00:00+00:00"), # Full datetime with zero time - # Invalid formats - reject anything without proper ISO8601/RFC3339 format - ("invalid datetime", None), # Completely invalid - ("2023-03-14 15:09:26", None), # Missing T delimiter - ("2023-03-14", None), # Date only, missing time and timezone - ("15:09:26Z", None), # Missing date component - ("2023-03-14T15:09:26", None), # Missing timezone - ("2023-03-14T15:09:26GMT", None), # Invalid timezone format - ("2023/03/14T15:09:26Z", None), # Wrong date separator - ("2023-03-14T25:09:26Z", None), # Invalid hour - ("2023-03-14T15:99:26Z", None), # Invalid minute - ("2023-03-14T15:09:99Z", None), # Invalid second - ], -) -def test_ab_datetime_try_parse(input_value, expected_output): - """Test datetime string format validation.""" - result = ab_datetime_try_parse(input_value) - if expected_output is None: - assert result is None - else: - assert isinstance(result, AirbyteDateTime) - assert str(result) == expected_output - - def test_epoch_millis(): """Test Unix epoch millisecond timestamp conversion methods.""" # Test to_epoch_millis() From dea2cc97013f718a0cad86f5b1d39245383fe13c Mon Sep 17 00:00:00 2001 From: Serhii Lazebnyi <53845333+lazebnyi@users.noreply.github.com> Date: Thu, 30 Jan 2025 19:43:23 +0100 Subject: [PATCH 02/12] feat(low-code): added json.loads to jwt authenticator (#301) Co-authored-by: octavia-squidington-iii --- 
airbyte_cdk/sources/declarative/auth/jwt.py | 28 +++++++++++-------- .../sources/declarative/auth/test_jwt.py | 14 ++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py index d7dd59282..c83d081bb 100644 --- a/airbyte_cdk/sources/declarative/auth/jwt.py +++ b/airbyte_cdk/sources/declarative/auth/jwt.py @@ -3,6 +3,7 @@ # import base64 +import json from dataclasses import InitVar, dataclass from datetime import datetime from typing import Any, Mapping, Optional, Union @@ -104,21 +105,21 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) def _get_jwt_headers(self) -> dict[str, Any]: - """ " + """ Builds and returns the headers used when signing the JWT. """ - headers = self._additional_jwt_headers.eval(self.config) + headers = self._additional_jwt_headers.eval(self.config, json_loads=json.loads) if any(prop in headers for prop in ["kid", "alg", "typ", "cty"]): raise ValueError( "'kid', 'alg', 'typ', 'cty' are reserved headers and should not be set as part of 'additional_jwt_headers'" ) if self._kid: - headers["kid"] = self._kid.eval(self.config) + headers["kid"] = self._kid.eval(self.config, json_loads=json.loads) if self._typ: - headers["typ"] = self._typ.eval(self.config) + headers["typ"] = self._typ.eval(self.config, json_loads=json.loads) if self._cty: - headers["cty"] = self._cty.eval(self.config) + headers["cty"] = self._cty.eval(self.config, json_loads=json.loads) headers["alg"] = self._algorithm return headers @@ -130,18 +131,19 @@ def _get_jwt_payload(self) -> dict[str, Any]: exp = now + self._token_duration if isinstance(self._token_duration, int) else now nbf = now - payload = self._additional_jwt_payload.eval(self.config) + payload = self._additional_jwt_payload.eval(self.config, json_loads=json.loads) if any(prop in payload for prop in ["iss", "sub", "aud", "iat", "exp", "nbf"]): raise ValueError( "'iss', 'sub', 'aud', 'iat', 'exp', 'nbf' are reserved properties and should not be set as part of 'additional_jwt_payload'" ) if self._iss: - payload["iss"] = self._iss.eval(self.config) + payload["iss"] = self._iss.eval(self.config, json_loads=json.loads) if self._sub: - payload["sub"] = self._sub.eval(self.config) + payload["sub"] = self._sub.eval(self.config, json_loads=json.loads) if self._aud: - payload["aud"] = self._aud.eval(self.config) + payload["aud"] = self._aud.eval(self.config, json_loads=json.loads) + payload["iat"] = now payload["exp"] = exp payload["nbf"] = nbf @@ -151,7 +153,7 @@ def _get_secret_key(self) -> str: """ Returns the secret key used to sign the JWT. """ - secret_key: str = self._secret_key.eval(self.config) + secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads) return ( base64.b64encode(secret_key.encode()).decode() if self._base64_encode_secret_key @@ -176,7 +178,11 @@ def _get_header_prefix(self) -> Union[str, None]: """ Returns the header prefix to be used when attaching the token to the request. 
""" - return self._header_prefix.eval(self.config) if self._header_prefix else None + return ( + self._header_prefix.eval(self.config, json_loads=json.loads) + if self._header_prefix + else None + ) @property def auth_header(self) -> str: diff --git a/unit_tests/sources/declarative/auth/test_jwt.py b/unit_tests/sources/declarative/auth/test_jwt.py index fe727b980..49b7ea570 100644 --- a/unit_tests/sources/declarative/auth/test_jwt.py +++ b/unit_tests/sources/declarative/auth/test_jwt.py @@ -126,6 +126,20 @@ def test_get_secret_key(self, base64_encode_secret_key, secret_key, expected): ) assert authenticator._get_secret_key() == expected + def test_get_secret_key_from_config( + self, + ): + authenticator = JwtAuthenticator( + config={"secrets": '{"secret_key": "test"}'}, + parameters={}, + secret_key="{{ json_loads(config['secrets'])['secret_key'] }}", + algorithm="test_algo", + token_duration=1200, + base64_encode_secret_key=False, + ) + expected = "test" + assert authenticator._get_secret_key() == expected + def test_get_signed_token(self): authenticator = JwtAuthenticator( config={}, From ee537afe0011c01d5124f1c0c556a8b5ff8ad70e Mon Sep 17 00:00:00 2001 From: Artem Inzhyyants <36314070+artem1205@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:40:16 +0100 Subject: [PATCH 03/12] feat: use create_concurrent_cursor_from_perpartition_cursor (#286) Signed-off-by: Artem Inzhyyants --- .../declarative/async_job/job_orchestrator.py | 8 +++--- .../concurrent_declarative_source.py | 3 ++- .../sources/declarative/declarative_stream.py | 4 ++- .../parsers/model_to_component_factory.py | 27 ++++++++++++++++++- .../async_job_partition_router.py | 10 +++---- .../declarative/retrievers/async_retriever.py | 18 +++++-------- .../async_job/test_job_orchestrator.py | 3 +-- .../test_async_job_partition_router.py | 20 ++++++-------- 8 files changed, 55 insertions(+), 38 deletions(-) diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py index 3938b8c07..398cee9ff 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py +++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py @@ -482,16 +482,16 @@ def _is_breaking_exception(self, exception: Exception) -> bool: and exception.failure_type == FailureType.config_error ) - def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: + def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]: """ - Fetches records from the given partition's jobs. + Fetches records from the given jobs. Args: - partition (AsyncPartition): The partition containing the jobs. + async_jobs Iterable[AsyncJob]: The list of AsyncJobs. Yields: Iterable[Mapping[str, Any]]: The fetched records from the jobs. 
""" - for job in partition.jobs: + for job in async_jobs: yield from self._job_repository.fetch_records(job) self._job_repository.delete(job) diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 3293731fd..92f4bdc4b 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -19,6 +19,7 @@ from airbyte_cdk.sources.declarative.extractors.record_filter import ( ClientSideIncrementalRecordFilterDecorator, ) +from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import ( PerPartitionWithGlobalCursor, @@ -231,7 +232,7 @@ def _group_streams( ): cursor = declarative_stream.retriever.stream_slicer.stream_slicer - if not isinstance(cursor, ConcurrentCursor): + if not isinstance(cursor, ConcurrentCursor | ConcurrentPerPartitionCursor): # This should never happen since we instantiate ConcurrentCursor in # model_to_component_factory.py raise ValueError( diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py index 12cdd3337..f7b97f3b4 100644 --- a/airbyte_cdk/sources/declarative/declarative_stream.py +++ b/airbyte_cdk/sources/declarative/declarative_stream.py @@ -138,7 +138,9 @@ def read_records( """ :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state. """ - if stream_slice is None or stream_slice == {}: + if stream_slice is None or ( + not isinstance(stream_slice, StreamSlice) and stream_slice == {} + ): # As the parameter is Optional, many would just call `read_records(sync_mode)` during testing without specifying the field # As part of the declarative model without custom components, this should never happen as the CDK would wire up a # SinglePartitionRouter that would create this StreamSlice properly diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index a8736986e..b8eeca1ec 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1656,7 +1656,7 @@ def _build_stream_slicer_from_partition_router( ) -> Optional[PartitionRouter]: if ( hasattr(model, "partition_router") - and isinstance(model, SimpleRetrieverModel) + and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel) and model.partition_router ): stream_slicer_model = model.partition_router @@ -1690,6 +1690,31 @@ def _merge_stream_slicers( stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) if model.incremental_sync and stream_slicer: + if model.retriever.type == "AsyncRetriever": + if model.incremental_sync.type != "DatetimeBasedCursor": + # We are currently in a transition to the Concurrent CDK and AsyncRetriever can only work with the support or unordered slices (for example, when we trigger reports for January and February, the report in February can be completed first). Once we have support for custom concurrent cursor or have a new implementation available in the CDK, we can enable more cursors here. 
+ raise ValueError( + "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet" + ) + if stream_slicer: + return self.create_concurrent_cursor_from_perpartition_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + state_manager=self._connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + stream_state={}, + partition_router=stream_slicer, + ) + return self.create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + ) + incremental_sync_model = model.incremental_sync if ( hasattr(incremental_sync_model, "global_substream_cursor") diff --git a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py index 0f11820f7..38a4f5328 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py @@ -4,9 +4,9 @@ from typing import Any, Callable, Iterable, Mapping, Optional from airbyte_cdk.models import FailureType +from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( AsyncJobOrchestrator, - AsyncPartition, ) from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( SinglePartitionRouter, @@ -42,12 +42,12 @@ def stream_slices(self) -> Iterable[StreamSlice]: for completed_partition in self._job_orchestrator.create_and_get_completed_partitions(): yield StreamSlice( - partition=dict(completed_partition.stream_slice.partition) - | {"partition": completed_partition}, + partition=dict(completed_partition.stream_slice.partition), cursor_slice=completed_partition.stream_slice.cursor_slice, + extra_fields={"jobs": list(completed_partition.jobs)}, ) - def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: + def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]: """ This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should be responsible for. 
However, this was added in because the JobOrchestrator is required to @@ -62,4 +62,4 @@ def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any] failure_type=FailureType.system_error, ) - return self._job_orchestrator.fetch_records(partition=partition) + return self._job_orchestrator.fetch_records(async_jobs=async_jobs) diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py index bd28e0e2d..24f52cfd3 100644 --- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py @@ -6,7 +6,7 @@ from typing_extensions import deprecated -from airbyte_cdk.models import FailureType +from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import ( @@ -16,7 +16,6 @@ from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.types import Config, StreamSlice, StreamState -from airbyte_cdk.utils.traced_exception import AirbyteTracedException @deprecated( @@ -57,9 +56,9 @@ def _get_stream_state(self) -> StreamState: return self.state - def _validate_and_get_stream_slice_partition( + def _validate_and_get_stream_slice_jobs( self, stream_slice: Optional[StreamSlice] = None - ) -> AsyncPartition: + ) -> Iterable[AsyncJob]: """ Validates the stream_slice argument and returns the partition from it. @@ -73,12 +72,7 @@ def _validate_and_get_stream_slice_partition( AirbyteTracedException: If the stream_slice is not an instance of StreamSlice or if the partition is not present in the stream_slice. """ - if not isinstance(stream_slice, StreamSlice) or "partition" not in stream_slice.partition: - raise AirbyteTracedException( - message="Invalid arguments to AsyncRetriever.read_records: stream_slice is not optional. 
Please contact Airbyte Support", - failure_type=FailureType.system_error, - ) - return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices + return stream_slice.extra_fields.get("jobs", []) if stream_slice else [] def stream_slices(self) -> Iterable[Optional[StreamSlice]]: return self.stream_slicer.stream_slices() @@ -89,8 +83,8 @@ def read_records( stream_slice: Optional[StreamSlice] = None, ) -> Iterable[StreamData]: stream_state: StreamState = self._get_stream_state() - partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice) - records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition) + jobs: Iterable[AsyncJob] = self._validate_and_get_stream_slice_jobs(stream_slice) + records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(jobs) yield from self.record_selector.filter_and_transform( all_data=records, diff --git a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py index d2fb9018f..dc81eacbc 100644 --- a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py +++ b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py @@ -174,9 +174,8 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None: orchestrator = self._orchestrator([_A_STREAM_SLICE]) first_job = _create_job() second_job = _create_job() - partition = AsyncPartition([first_job, second_job], _A_STREAM_SLICE) - records = list(orchestrator.fetch_records(partition)) + records = list(orchestrator.fetch_records([first_job, second_job])) assert len(records) == 2 assert self._job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)] diff --git a/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py index ccc57cc91..2a5ac3277 100644 --- a/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py @@ -35,12 +35,12 @@ def test_stream_slices_with_single_partition_router(): slices = list(partition_router.stream_slices()) assert len(slices) == 1 - partition = slices[0].partition.get("partition") - assert isinstance(partition, AsyncPartition) - assert partition.stream_slice == StreamSlice(partition={}, cursor_slice={}) - assert partition.status == AsyncJobStatus.COMPLETED + partition = slices[0] + assert isinstance(partition, StreamSlice) + assert partition == StreamSlice(partition={}, cursor_slice={}) + assert partition.extra_fields["jobs"][0].status() == AsyncJobStatus.COMPLETED - attempts_per_job = list(partition.jobs) + attempts_per_job = list(partition.extra_fields["jobs"]) assert len(attempts_per_job) == 1 assert attempts_per_job[0].api_job_id() == "a_job_id" assert attempts_per_job[0].job_parameters() == StreamSlice(partition={}, cursor_slice={}) @@ -68,14 +68,10 @@ def test_stream_slices_with_parent_slicer(): slices = list(partition_router.stream_slices()) assert len(slices) == 3 for i, partition in enumerate(slices): - partition = partition.partition.get("partition") - assert isinstance(partition, AsyncPartition) - assert partition.stream_slice == StreamSlice( - partition={"parent_id": str(i)}, cursor_slice={} - ) - assert partition.status == AsyncJobStatus.COMPLETED + assert isinstance(partition, StreamSlice) + assert 
partition == StreamSlice(partition={"parent_id": str(i)}, cursor_slice={}) - attempts_per_job = list(partition.jobs) + attempts_per_job = list(partition.extra_fields["jobs"]) assert len(attempts_per_job) == 1 assert attempts_per_job[0].api_job_id() == "a_job_id" assert attempts_per_job[0].job_parameters() == StreamSlice( From 65e6a0d12e19a84db360053d3ec793c66b203512 Mon Sep 17 00:00:00 2001 From: Baz Date: Fri, 31 Jan 2025 00:33:43 +0200 Subject: [PATCH 04/12] fix: (OAuthAuthenticator) - get the `access_token`, `refresh_token`, `expires_in` recursively from `response` (#285) --- .../requests_native_auth/abstract_oauth.py | 270 ++++++++++++++---- .../http/requests_native_auth/oauth.py | 217 +++++++++----- .../test_connector_builder_handler.py | 2 +- .../test_requests_native_auth.py | 86 +++++- 4 files changed, 442 insertions(+), 133 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index dd2b3057b..0a9b15bc0 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -25,6 +25,13 @@ _NOOP_MESSAGE_REPOSITORY = NoopMessageRepository() +class ResponseKeysMaxRecurtionReached(AirbyteTracedException): + """ + Raised when the max level of recursion is reached, when trying to + find-and-get the target key, during the `_make_handled_request` + """ + + class AbstractOauth2Authenticator(AuthBase): """ Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator @@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques request.headers.update(self.get_auth_header()) return request + @property + def _is_access_token_flow(self) -> bool: + return self.get_token_refresh_endpoint() is None and self.access_token is not None + + @property + def token_expiry_is_time_of_expiration(self) -> bool: + """ + Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. 
+ """ + + return False + + @property + def token_expiry_date_format(self) -> Optional[str]: + """ + Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires + """ + + return None + def get_auth_header(self) -> Mapping[str, Any]: """HTTP header to set on the requests""" token = self.access_token if self._is_access_token_flow else self.get_access_token() return {"Authorization": f"Bearer {token}"} - @property - def _is_access_token_flow(self) -> bool: - return self.get_token_refresh_endpoint() is None and self.access_token is not None - def get_access_token(self) -> str: """Returns the access token""" if self.token_has_expired(): @@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None + def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + """ + Returns the refresh token and its expiration datetime + + :return: a tuple of (access_token, token_lifespan) + """ + response_json = self._make_handled_request() + self._ensure_access_token_in_response(response_json) + + return ( + self._extract_access_token(response_json), + self._extract_token_expiry_date(response_json), + ) + + # ---------------- + # PRIVATE METHODS + # ---------------- + def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException ) -> bool: + """ + Wraps and handles exceptions that occur during the refresh token process. + + This method checks if the provided exception is related to a refresh token error + by examining the response status code and specific error content. + + Args: + exception (requests.exceptions.RequestException): The exception raised during the request. + + Returns: + bool: True if the exception is related to a refresh token error, False otherwise. + """ try: if exception.response is not None: exception_content = exception.response.json() @@ -131,7 +184,24 @@ def _wrap_refresh_token_exception( ), max_time=300, ) - def _get_refresh_access_token_response(self) -> Any: + def _make_handled_request(self) -> Any: + """ + Makes a handled HTTP request to refresh an OAuth token. + + This method sends a POST request to the token refresh endpoint with the necessary + headers and body to obtain a new access token. It handles various exceptions that + may occur during the request and logs the response for troubleshooting purposes. + + Returns: + Mapping[str, Any]: The JSON response from the token refresh endpoint. + + Raises: + DefaultBackoffException: If the response status code is 429 (Too Many Requests) + or any 5xx server error. + AirbyteTracedException: If the refresh token is invalid or expired, prompting + re-authentication. + Exception: For any other exceptions that occur during the request. + """ try: response = requests.request( method="POST", @@ -139,22 +209,10 @@ def _get_refresh_access_token_response(self) -> Any: data=self.build_refresh_request_body(), headers=self.build_refresh_request_headers(), ) - if response.ok: - response_json = response.json() - # Add the access token to the list of secrets so it is replaced before logging the response - # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen... 
-                access_key = response_json.get(self.get_access_token_name())
-                if not access_key:
-                    raise Exception(
-                        "Token refresh API response was missing access token {self.get_access_token_name()}"
-                    )
-                add_to_secrets(access_key)
-                self._log_response(response)
-                return response_json
-            else:
-                # log the response even if the request failed for troubleshooting purposes
-                self._log_response(response)
-                response.raise_for_status()
+            # log the response even if the request failed for troubleshooting purposes
+            self._log_response(response)
+            response.raise_for_status()
+            return response.json()
         except requests.exceptions.RequestException as e:
             if e.response is not None:
                 if e.response.status_code == 429 or e.response.status_code >= 500:
@@ -168,17 +226,34 @@
         except Exception as e:
             raise Exception(f"Error while refreshing access token: {e}") from e
 
-    def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
+    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
         """
-        Returns the refresh token and its expiration datetime
+        Ensures that the access token is present in the response data.
 
-        :return: a tuple of (access_token, token_lifespan)
-        """
-        response_json = self._get_refresh_access_token_response()
+        This method attempts to extract the access token from the provided response data.
+        If the access token is not found, it raises an exception indicating that the token
+        refresh API response was missing the access token. If the access token is found,
+        it adds the token to the list of secrets to ensure it is replaced before logging
+        the response.
+
+        Args:
+            response_data (Mapping[str, Any]): The response data from which to extract the access token.
 
-        return response_json[self.get_access_token_name()], response_json[
-            self.get_expires_in_name()
-        ]
+        Raises:
+            Exception: If the access token is not found in the response data.
+            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
+        """
+        try:
+            access_key = self._extract_access_token(response_data)
+            if not access_key:
+                raise Exception(
+                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
+                )
+            # Add the access token to the list of secrets so it is replaced before logging the response
+            # An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
+            add_to_secrets(access_key)
+        except ResponseKeysMaxRecurtionReached as e:
+            raise e
 
     def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
         """
@@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime
             f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
         )
 
-    @property
-    def token_expiry_is_time_of_expiration(self) -> bool:
+    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
         """
-        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
+        Extracts the access token from the given response data.
+
+        Args:
+            response_data (Mapping[str, Any]): The response data from which to extract the access token.
+
+        Returns:
+            str: The extracted access token.
""" + return self._find_and_get_value_from_response(response_data, self.get_access_token_name()) - return False + def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: + """ + Extracts the refresh token from the given response data. - @property - def token_expiry_date_format(self) -> Optional[str]: + Args: + response_data (Mapping[str, Any]): The response data from which to extract the refresh token. + + Returns: + str: The extracted refresh token. """ - Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires + return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) + + def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: + """ + Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. + + Args: + response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. + + Returns: + str: The extracted token_expiry_date. """ + return self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + + def _find_and_get_value_from_response( + self, + response_data: Mapping[str, Any], + key_name: str, + max_depth: int = 5, + current_depth: int = 0, + ) -> Any: + """ + Recursively searches for a specified key in a nested dictionary or list and returns its value if found. + + Args: + response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list. + key_name (str): The key to search for in the response data. + max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5. + current_depth (int, optional): The current depth of the recursion. Defaults to 0. + + Returns: + Any: The value associated with the specified key if found, otherwise None. + + Raises: + AirbyteTracedException: If the maximum recursion depth is reached without finding the key. + """ + if current_depth > max_depth: + # this is needed to avoid an inf loop, possible with a very deep nesting observed. + message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response." + raise ResponseKeysMaxRecurtionReached( + internal_message=message, message=message, failure_type=FailureType.config_error + ) + + if isinstance(response_data, dict): + # get from the root level + if key_name in response_data: + return response_data[key_name] + + # get from the nested object + for _, value in response_data.items(): + result = self._find_and_get_value_from_response( + value, key_name, max_depth, current_depth + 1 + ) + if result is not None: + return result + + # get from the nested array object + elif isinstance(response_data, list): + for item in response_data: + result = self._find_and_get_value_from_response( + item, key_name, max_depth, current_depth + 1 + ) + if result is not None: + return result return None + @property + def _message_repository(self) -> Optional[MessageRepository]: + """ + The implementation can define a message_repository if it wants debugging logs for HTTP requests + """ + return _NOOP_MESSAGE_REPOSITORY + + def _log_response(self, response: requests.Response) -> None: + """ + Logs the HTTP response using the message repository if it is available. + + Args: + response (requests.Response): The HTTP response to log. 
+ """ + if self._message_repository: + self._message_repository.log_message( + Level.DEBUG, + lambda: format_http_message( + response, + "Refresh token", + "Obtains access token", + self._NO_STREAM_NAME, + is_auxiliary=True, + ), + ) + + # ---------------- + # ABSTR METHODS + # ---------------- + @abstractmethod def get_token_refresh_endpoint(self) -> Optional[str]: """Returns the endpoint to refresh the access token""" @@ -295,23 +473,3 @@ def access_token(self) -> str: @abstractmethod def access_token(self, value: str) -> str: """Setter for the access token""" - - @property - def _message_repository(self) -> Optional[MessageRepository]: - """ - The implementation can define a message_repository if it wants debugging logs for HTTP requests - """ - return _NOOP_MESSAGE_REPOSITORY - - def _log_response(self, response: requests.Response) -> None: - if self._message_repository: - self._message_repository.log_message( - Level.DEBUG, - lambda: format_http_message( - response, - "Refresh token", - "Obtains access token", - self._NO_STREAM_NAME, - is_auxiliary=True, - ), - ) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 5cbe17e0a..2ff2f60e9 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -51,7 +51,7 @@ def __init__( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = "", refresh_token_error_values: Tuple[str, ...] = (), - ): + ) -> None: self._token_refresh_endpoint = token_refresh_endpoint self._client_secret_name = client_secret_name self._client_secret = client_secret @@ -175,7 +175,7 @@ def __init__( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = "", refresh_token_error_values: Tuple[str, ...] 
= (), - ): + ) -> None: """ Args: connector_config (Mapping[str, Any]): The full connector configuration @@ -196,18 +196,12 @@ def __init__( token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration message_repository (MessageRepository): the message repository used to emit logs on HTTP requests and control message on config update """ - self._client_id = ( - client_id # type: ignore[assignment] # Incorrect type for assignment - if client_id is not None - else dpath.get(connector_config, ("credentials", "client_id")) # type: ignore[arg-type] + self._connector_config = connector_config + self._client_id: str = self._get_config_value_by_path( + ("credentials", "client_id"), client_id ) - self._client_secret = ( - client_secret # type: ignore[assignment] # Incorrect type for assignment - if client_secret is not None - else dpath.get( - connector_config, # type: ignore[arg-type] - ("credentials", "client_secret"), - ) + self._client_secret: str = self._get_config_value_by_path( + ("credentials", "client_secret"), client_secret ) self._client_id_name = client_id_name self._client_secret_name = client_secret_name @@ -222,9 +216,9 @@ def __init__( super().__init__( token_refresh_endpoint=token_refresh_endpoint, client_id_name=self._client_id_name, - client_id=self.get_client_id(), + client_id=self._client_id, client_secret_name=self._client_secret_name, - client_secret=self.get_client_secret(), + client_secret=self._client_secret, refresh_token=self.get_refresh_token(), refresh_token_name=self._refresh_token_name, scopes=scopes, @@ -242,51 +236,62 @@ def __init__( refresh_token_error_values=refresh_token_error_values, ) - def get_refresh_token_name(self) -> str: - return self._refresh_token_name - - def get_client_id(self) -> str: - return self._client_id - - def get_client_secret(self) -> str: - return self._client_secret - @property def access_token(self) -> str: - return dpath.get( # type: ignore[return-value] - self._connector_config, # type: ignore[arg-type] - self._access_token_config_path, - default="", - ) + """ + Retrieve the access token from the configuration. + + Returns: + str: The access token. + """ + return self._get_config_value_by_path(self._access_token_config_path) # type: ignore[return-value] @access_token.setter def access_token(self, new_access_token: str) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._access_token_config_path, - new_access_token, - ) + """ + Sets a new access token. + + Args: + new_access_token (str): The new access token to be set. + """ + self._set_config_value_by_path(self._access_token_config_path, new_access_token) def get_refresh_token(self) -> str: - return dpath.get( # type: ignore[return-value] - self._connector_config, # type: ignore[arg-type] - self._refresh_token_config_path, - default="", - ) + """ + Retrieve the refresh token from the configuration. + + This method fetches the refresh token using the configuration path specified + by `_refresh_token_config_path`. + + Returns: + str: The refresh token as a string. + """ + return self._get_config_value_by_path(self._refresh_token_config_path) # type: ignore[return-value] def set_refresh_token(self, new_refresh_token: str) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._refresh_token_config_path, - new_refresh_token, - ) + """ + Updates the refresh token in the configuration. 
+ + Args: + new_refresh_token (str): The new refresh token to be set. + """ + self._set_config_value_by_path(self._refresh_token_config_path, new_refresh_token) def get_token_expiry_date(self) -> AirbyteDateTime: - expiry_date = dpath.get( - self._connector_config, # type: ignore[arg-type] - self._token_expiry_date_config_path, - default="", - ) + """ + Retrieves the token expiry date from the configuration. + + This method fetches the token expiry date from the configuration using the specified path. + If the expiry date is an empty string, it returns the current date and time minus one day. + Otherwise, it parses the expiry date string into an AirbyteDateTime object. + + Returns: + AirbyteDateTime: The parsed or calculated token expiry date. + + Raises: + TypeError: If the result is not an instance of AirbyteDateTime. + """ + expiry_date = self._get_config_value_by_path(self._token_expiry_date_config_path) result = ( ab_datetime_now() - timedelta(days=1) if expiry_date == "" @@ -296,14 +301,15 @@ def get_token_expiry_date(self) -> AirbyteDateTime: return result raise TypeError("Invalid datetime conversion") - def set_token_expiry_date( # type: ignore[override] - self, - new_token_expiry_date: AirbyteDateTime, - ) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._token_expiry_date_config_path, - str(new_token_expiry_date), + def set_token_expiry_date(self, new_token_expiry_date: AirbyteDateTime) -> None: # type: ignore[override] + """ + Sets the token expiry date in the configuration. + + Args: + new_token_expiry_date (AirbyteDateTime): The new expiry date for the token. + """ + self._set_config_value_by_path( + self._token_expiry_date_config_path, str(new_token_expiry_date) ) def token_has_expired(self) -> bool: @@ -315,6 +321,16 @@ def get_new_token_expiry_date( access_token_expires_in: str, token_expiry_date_format: str | None = None, ) -> AirbyteDateTime: + """ + Calculate the new token expiry date based on the provided expiration duration or format. + + Args: + access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format. + token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None. + + Returns: + AirbyteDateTime: The calculated expiry date of the access token. 
+ """ if token_expiry_date_format: return ab_datetime_parse(access_token_expires_in) else: @@ -336,27 +352,82 @@ def get_access_token(self) -> str: self.access_token = new_access_token self.set_refresh_token(new_refresh_token) self.set_token_expiry_date(new_token_expiry_date) - # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message - # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the - # message directly in the console, this is needed - if not isinstance(self._message_repository, NoopMessageRepository): - self._message_repository.emit_message( - create_connector_config_control_message(self._connector_config) # type: ignore[arg-type] - ) - else: - emit_configuration_as_airbyte_control_message(self._connector_config) # type: ignore[arg-type] + self._emit_control_message() return self.access_token - def refresh_access_token( # type: ignore[override] # Signature doesn't match base class - self, - ) -> Tuple[str, str, str]: - response_json = self._get_refresh_access_token_response() + def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] + """ + Refreshes the access token by making a handled request and extracting the necessary token information. + + Returns: + Tuple[str, str, str]: A tuple containing the new access token, token expiry date, and refresh token. + """ + response_json = self._make_handled_request() return ( - response_json[self.get_access_token_name()], - response_json[self.get_expires_in_name()], - response_json[self.get_refresh_token_name()], + self._extract_access_token(response_json), + self._extract_token_expiry_date(response_json), + self._extract_refresh_token(response_json), + ) + + def _set_config_value_by_path(self, config_path: Union[str, Sequence[str]], value: Any) -> None: + """ + Set a value in the connector configuration at the specified path. + + Args: + config_path (Union[str, Sequence[str]]): The path within the configuration where the value should be set. + This can be a string representing a single key or a sequence of strings representing a nested path. + value (Any): The value to set at the specified path in the configuration. + + Returns: + None + """ + dpath.new(self._connector_config, config_path, value) # type: ignore[arg-type] + + def _get_config_value_by_path( + self, config_path: Union[str, Sequence[str]], default: Optional[str] = None + ) -> str | Any: + """ + Retrieve a value from the connector configuration using a specified path. + + Args: + config_path (Union[str, Sequence[str]]): The path to the desired configuration value. This can be a string or a sequence of strings. + default (Optional[str], optional): The default value to return if the specified path does not exist in the configuration. Defaults to None. + + Returns: + Any: The value from the configuration at the specified path, or the default value if the path does not exist. + """ + return dpath.get( + self._connector_config, # type: ignore[arg-type] + config_path, + default=default if default is not None else "", ) + def _emit_control_message(self) -> None: + """ + Emits a control message based on the connector configuration. + + This method checks if the message repository is not a NoopMessageRepository. + If it is not, it emits a message using the message repository. Otherwise, + it falls back to emitting the configuration as an Airbyte control message + directly to the console for backward compatibility. 
+ + Note: + The function `emit_configuration_as_airbyte_control_message` has been deprecated + in favor of the package `airbyte_cdk.sources.message`. + + Raises: + TypeError: If the argument types are incorrect. + """ + # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message + # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the + # message directly in the console, this is needed + if not isinstance(self._message_repository, NoopMessageRepository): + self._message_repository.emit_message( + create_connector_config_control_message(self._connector_config) # type: ignore[arg-type] + ) + else: + emit_configuration_as_airbyte_control_message(self._connector_config) # type: ignore[arg-type] + @property def _message_repository(self) -> MessageRepository: """ diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index c00a7e2f1..b0c91ce30 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -600,7 +600,7 @@ def test_config_update() -> None: "expires_in": 3600, } with patch( - "airbyte_cdk.sources.streams.http.requests_native_auth.SingleUseRefreshTokenOauth2Authenticator._get_refresh_access_token_response", + "airbyte_cdk.sources.streams.http.requests_native_auth.SingleUseRefreshTokenOauth2Authenticator._make_handled_request", return_value=refresh_request_response, ): output = handle_connector_builder_request( diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 808126988..d756931c8 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -22,6 +22,9 @@ SingleUseRefreshTokenOauth2Authenticator, TokenAuthenticator, ) +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import ( + ResponseKeysMaxRecurtionReached, +) from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse @@ -258,7 +261,7 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, int) assert ("access_token", 1000) == (token, expires_in) - # Test with expires_in as str + # Test with expires_in as str(int) mocker.patch.object( resp, "json", return_value={"access_token": "access_token", "expires_in": "2000"} ) @@ -267,7 +270,7 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2000") == (token, expires_in) - # Test with expires_in as str + # Test with expires_in as datetime(str) mocker.patch.object( resp, "json", @@ -278,6 +281,78 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + # Test with nested access_token and expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={"data": {"access_token": "access_token_nested", "expires_in": "2001"}}, + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, str) + assert ("access_token_nested", "2001") == (token, expires_in) + + # Test with multiple nested levels access_token and 
expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={ + "data": { + "scopes": ["one", "two", "three"], + "data2": { + "not_access_token": "test_non_access_token_value", + "data3": { + "some_field": "test_value", + "expires_at": "2800", + "data4": { + "data5": { + "access_token": "access_token_deeply_nested", + "expires_in": "2002", + } + }, + }, + }, + } + }, + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, str) + assert ("access_token_deeply_nested", "2002") == (token, expires_in) + + # Test with max nested levels access_token and expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={ + "data": { + "scopes": ["one", "two", "three"], + "data2": { + "not_access_token": "test_non_access_token_value", + "data3": { + "some_field": "test_value", + "expires_at": "2800", + "data4": { + "data5": { + # this is the edge case, but worth testing. + "data6": { + "access_token": "access_token_super_deeply_nested", + "expires_in": "2003", + } + } + }, + }, + }, + } + }, + ) + with pytest.raises(ResponseKeysMaxRecurtionReached) as exc_info: + oauth.refresh_access_token() + error_message = "The maximum level of recursion is reached. Couldn't find the speficied `access_token` in the response." + assert exc_info.value.internal_message == error_message + assert exc_info.value.message == error_message + assert exc_info.value.failure_type == FailureType.config_error + def test_refresh_access_token_when_headers_provided(self, mocker): expected_headers = { "Authorization": "Bearer some_access_token", @@ -594,6 +669,11 @@ def test_given_message_repository_when_get_access_token_then_log_request( "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.format_http_message", return_value="formatted json", ) + # patching the `expires_in` + mocker.patch( + "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.AbstractOauth2Authenticator._find_and_get_value_from_response", + return_value="7200", + ) authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -608,7 +688,7 @@ def test_refresh_access_token(self, mocker, connector_config): client_secret=connector_config["credentials"]["client_secret"], ) - authenticator._get_refresh_access_token_response = mocker.Mock( + authenticator._make_handled_request = mocker.Mock( return_value={ authenticator.get_access_token_name(): "new_access_token", authenticator.get_expires_in_name(): "42", From 10a7a873b9212b7e1af230ed13ce833c3741a3c4 Mon Sep 17 00:00:00 2001 From: Patrick Nilan Date: Fri, 31 Jan 2025 09:23:01 -0800 Subject: [PATCH 05/12] fix: Relocates `self._extract_slice_fields(stream_slice=stream_slice)` within conditional codeblock that uses it. 
(#304) --- airbyte_cdk/sources/streams/http/http.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/http.py b/airbyte_cdk/sources/streams/http/http.py index 40eab27a3..fbf4fe35d 100644 --- a/airbyte_cdk/sources/streams/http/http.py +++ b/airbyte_cdk/sources/streams/http/http.py @@ -423,8 +423,6 @@ def _read_pages( stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[StreamData]: - partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice) - stream_state = stream_state or {} pagination_complete = False next_page_token = None @@ -438,6 +436,7 @@ def _read_pages( cursor = self.get_cursor() if cursor and isinstance(cursor, SubstreamResumableFullRefreshCursor): + partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice) # Substreams checkpoint state by marking an entire parent partition as completed so that on the subsequent attempt # after a failure, completed parents are skipped and the sync can make progress cursor.close_slice(StreamSlice(cursor_slice={}, partition=partition)) From e57d38a515df75860f44b67e44cfe5a4e7269b8f Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 1 Feb 2025 03:30:33 +0000 Subject: [PATCH 06/12] docs: add migration note for pendulum removal (#308) Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Aaron Steers --- cdk-migrations.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cdk-migrations.md b/cdk-migrations.md index fb2163ef8..8173a5edd 100644 --- a/cdk-migrations.md +++ b/cdk-migrations.md @@ -1,5 +1,12 @@ # CDK Migration Guide +## Upgrading to 6.28.0 + +Starting from version 6.28.0, the CDK no longer includes Pendulum as a transitive dependency. If your connector relies on Pendulum without explicitly declaring it as a dependency, you will need to add it to your connector's dependencies going forward. + +More info: +- https://deptry.com/rules-violations/#transitive-dependencies-dep003 + ## Upgrading to 6.0.0 Version 6.x.x of the CDK introduces concurrent processing of low-code incremental streams. This is breaking because non-manifest only connectors must update their self-managed `run.py` and `source.py` files. This section is intended to clarify how to upgrade a low-code connector to use the Concurrent CDK to sync incremental streams. From 426ab5b65cf8ed45e8c2b7bc236490aba5e6c624 Mon Sep 17 00:00:00 2001 From: Serhii Lazebnyi <53845333+lazebnyi@users.noreply.github.com> Date: Sat, 1 Feb 2025 20:27:31 +0100 Subject: [PATCH 07/12] fix(low-code): add wrong dynamic stream name type validation (#305) --- .../manifest_declarative_source.py | 5 ++++ .../test_http_components_resolver.py | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index 78aeac23f..efc779464 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -365,6 +365,11 @@ def _dynamic_stream_configs( # Ensure that each stream is created with a unique name name = dynamic_stream.get("name") + if not isinstance(name, str): + raise ValueError( + f"Expected stream name {name} to be a string, got {type(name)}." 
+                )
+
             if name in seen_dynamic_streams:
                 error_message = f"Dynamic streams list contains a duplicate name: {name}. Please contact Airbyte Support."
                 failure_type = FailureType.system_error
diff --git a/unit_tests/sources/declarative/resolvers/test_http_components_resolver.py b/unit_tests/sources/declarative/resolvers/test_http_components_resolver.py
index f09ede0d6..357dcceef 100644
--- a/unit_tests/sources/declarative/resolvers/test_http_components_resolver.py
+++ b/unit_tests/sources/declarative/resolvers/test_http_components_resolver.py
@@ -3,6 +3,7 @@
 #
 
 import json
+from copy import deepcopy
 from unittest.mock import MagicMock
 
 import pytest
@@ -362,6 +363,34 @@ def test_http_components_resolver(
     assert result == expected_result
 
 
+def test_wrong_stream_name_type():
+    with HttpMocker() as http_mocker:
+        http_mocker.get(
+            HttpRequest(url="https://api.test.com/int_items"),
+            HttpResponse(
+                body=json.dumps(
+                    [
+                        {"id": 1, "name": 1},
+                        {"id": 2, "name": 2},
+                    ]
+                )
+            ),
+        )
+
+        manifest = deepcopy(_MANIFEST)
+        manifest["dynamic_streams"][0]["components_resolver"]["retriever"]["requester"]["path"] = (
+            "int_items"
+        )
+
+        source = ConcurrentDeclarativeSource(
+            source_config=manifest, config=_CONFIG, catalog=None, state=None
+        )
+        with pytest.raises(ValueError) as exc_info:
+            source.discover(logger=source.logger, config=_CONFIG)
+
+        assert str(exc_info.value) == "Expected stream name 1 to be a string, got <class 'int'>."
+
+
 @pytest.mark.parametrize(
     "components_mapping, retriever_data, stream_template_config, expected_result",
     [
From ef97304f45f3c54ac3f439dd45471aeb8272ea0c Mon Sep 17 00:00:00 2001
From: Maxime Carbonneau-Leclerc <3360483+maxi297@users.noreply.github.com>
Date: Sat, 1 Feb 2025 19:35:28 -0500
Subject: [PATCH 08/12] feat(low-code): improve logging on async retriever errors (#307)

---
 .../sources/declarative/async_job/job_orchestrator.py      | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
index 398cee9ff..bb8fb85f8 100644
--- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
+++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -437,10 +437,10 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]:
                 yield from self._process_running_partitions_and_yield_completed_ones()
                 self._wait_on_status_update()
             except Exception as exception:
+                LOGGER.warning(
+                    f"Caught exception that stops the processing of the jobs: {exception}. 
Traceback: {traceback.format_exc()}" + ) if self._is_breaking_exception(exception): - LOGGER.warning( - f"Caught exception that stops the processing of the jobs: {exception}" - ) self._abort_all_running_jobs() raise exception From 979598cfcf653c43993cf2dbb4c03f0ef4154c78 Mon Sep 17 00:00:00 2001 From: Christo Grabowski <108154848+ChristoGrab@users.noreply.github.com> Date: Tue, 4 Feb 2025 14:33:21 -0500 Subject: [PATCH 09/12] chore: use python 3.11 for connector tests in CI (#313) --- .github/workflows/connector-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/connector-tests.yml b/.github/workflows/connector-tests.yml index 4f6cedee0..ca2521f97 100644 --- a/.github/workflows/connector-tests.yml +++ b/.github/workflows/connector-tests.yml @@ -128,7 +128,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.11" # Create initial pending status for test report - name: Create Pending Test Report Status if: steps.no_changes.outputs.status != 'cancelled' From 126e2339add6c8cab2b9e9e8218645e96c6936c4 Mon Sep 17 00:00:00 2001 From: Christo Grabowski <108154848+ChristoGrab@users.noreply.github.com> Date: Tue, 4 Feb 2025 15:05:59 -0500 Subject: [PATCH 10/12] feat: enable handling of nested fields when injecting request_option in request body_json (#201) --- airbyte_cdk/sources/declarative/auth/token.py | 11 +- .../declarative_component_schema.yaml | 18 +- .../incremental/datetime_based_cursor.py | 13 +- .../models/declarative_component_schema.py | 14 +- .../parsers/model_to_component_factory.py | 64 +++--- .../list_partition_router.py | 6 +- .../substream_partition_router.py | 12 +- .../declarative/requesters/http_requester.py | 6 +- .../paginators/default_paginator.py | 11 +- .../declarative/requesters/request_option.py | 87 +++++++- ...datetime_based_request_options_provider.py | 13 +- .../retrievers/simple_retriever.py | 5 +- airbyte_cdk/utils/mapping_helpers.py | 113 +++++++--- .../declarative/auth/test_token_auth.py | 35 +++ .../incremental/test_datetime_based_cursor.py | 41 +++- .../test_model_to_component_factory.py | 30 +-- .../paginators/test_default_paginator.py | 4 +- .../request_options/test_request_options.py | 199 ++++++++++++++++++ .../requesters/test_http_requester.py | 4 +- .../retrievers/test_simple_retriever.py | 39 +++- .../test_manifest_declarative_source.py | 4 +- unit_tests/utils/test_mapping_helpers.py | 160 +++++++++----- 22 files changed, 707 insertions(+), 182 deletions(-) create mode 100644 unit_tests/sources/declarative/requesters/request_options/test_request_options.py diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py index 12fb899b9..caecf9d2c 100644 --- a/airbyte_cdk/sources/declarative/auth/token.py +++ b/airbyte_cdk/sources/declarative/auth/token.py @@ -5,7 +5,7 @@ import base64 import logging from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any, Mapping, MutableMapping, Union import requests from cachetools import TTLCache, cached @@ -45,11 +45,6 @@ class ApiKeyAuthenticator(DeclarativeAuthenticator): config: Config parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._field_name = InterpolatedString.create( - self.request_option.field_name, parameters=parameters - ) - @property def auth_header(self) -> str: options = self._get_request_options(RequestOptionType.header) @@ -60,9 +55,9 @@ def 
token(self) -> str: return self.token_provider.get_token() def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]: - options = {} + options: MutableMapping[str, Any] = {} if self.request_option.inject_into == option_type: - options[self._field_name.eval(self.config)] = self.token + self.request_option.inject_into_request(options, self.token, self.config) return options def get_request_params(self) -> Mapping[str, Any]: diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index d51d4c922..072a1efcd 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -2847,25 +2847,35 @@ definitions: enum: [RequestPath] RequestOption: title: Request Option - description: Specifies the key field and where in the request a component's value should be injected. + description: Specifies the key field or path and where in the request a component's value should be injected. type: object required: - type - - field_name - inject_into properties: type: type: string enum: [RequestOption] field_name: - title: Request Option - description: Configures which key should be used in the location that the descriptor is being injected into + title: Field Name + description: Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder. type: string examples: - segment_id interpolation_context: - config - parameters + field_path: + title: Field Path + description: Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries) + type: array + items: + type: string + examples: + - ["data", "viewer", "id"] + interpolation_context: + - config + - parameters inject_into: title: Inject Into description: Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated. 
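A minimal sketch of what the new `field_path` option above enables, based on the
`RequestOption.inject_into_request()` implementation and the unit tests added later in
this patch. The path segments, config values, and injected value below are illustrative,
not taken from the diff:

```python
from airbyte_cdk.sources.declarative.requesters.request_option import (
    RequestOption,
    RequestOptionType,
)

# field_path segments may be interpolated strings; nested injection is only
# valid for body_json, since query params, headers, and form-encoded body
# data are flat key/value structures. ("section" and 42 are made-up values.)
option = RequestOption(
    field_path=["data", "{{ config['section'] }}", "id"],
    inject_into=RequestOptionType.body_json,
    parameters={},
)

body: dict = {}
option.inject_into_request(body, 42, config={"section": "user"})
assert body == {"data": {"user": {"id": 42}}}
```

The `body_json`-only restriction is enforced by the validation added to
`request_option.py` below.
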
diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index d6d329aec..8ef1c89a4 100644 --- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -365,14 +365,15 @@ def _get_request_options( options: MutableMapping[str, Any] = {} if not stream_slice: return options + if self.start_time_option and self.start_time_option.inject_into == option_type: - options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string - self._partition_field_start.eval(self.config) - ) + start_time_value = stream_slice.get(self._partition_field_start.eval(self.config)) + self.start_time_option.inject_into_request(options, start_time_value, self.config) + if self.end_time_option and self.end_time_option.inject_into == option_type: - options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr] - self._partition_field_end.eval(self.config) - ) + end_time_value = stream_slice.get(self._partition_field_end.eval(self.config)) + self.end_time_option.inject_into_request(options, end_time_value, self.config) + return options def should_be_synced(self, record: Record) -> bool: diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 6aa1d35a7..fe29cee2c 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -1200,11 +1200,17 @@ class InjectInto(Enum): class RequestOption(BaseModel): type: Literal["RequestOption"] - field_name: str = Field( - ..., - description="Configures which key should be used in the location that the descriptor is being injected into", + field_name: Optional[str] = Field( + None, + description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.", examples=["segment_id"], - title="Request Option", + title="Field Name", + ) + field_path: Optional[List[str]] = Field( + None, + description="Configures a path to be used for nested structures in JSON body requests (e.g. 
GraphQL queries)", + examples=[["data", "viewer", "id"]], + title="Field Path", ) inject_into: InjectInto = Field( ..., diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index b8eeca1ec..a664b8530 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -733,8 +733,8 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[ } return names_to_types[value_type] - @staticmethod def create_api_key_authenticator( + self, model: ApiKeyAuthenticatorModel, config: Config, token_provider: Optional[TokenProvider] = None, @@ -756,10 +756,8 @@ def create_api_key_authenticator( ) request_option = ( - RequestOption( - inject_into=RequestOptionType(model.inject_into.inject_into.value), - field_name=model.inject_into.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.inject_into, config, parameters=model.parameters or {} ) if model.inject_into else RequestOption( @@ -768,6 +766,7 @@ def create_api_key_authenticator( parameters=model.parameters or {}, ) ) + return ApiKeyAuthenticator( token_provider=( token_provider @@ -849,7 +848,7 @@ def create_session_token_authenticator( token_provider=token_provider, ) else: - return ModelToComponentFactory.create_api_key_authenticator( + return self.create_api_key_authenticator( ApiKeyAuthenticatorModel( type="ApiKeyAuthenticator", api_token="", @@ -1489,19 +1488,15 @@ def create_datetime_based_cursor( ) end_time_option = ( - RequestOption( - inject_into=RequestOptionType(model.end_time_option.inject_into.value), - field_name=model.end_time_option.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.end_time_option, config, parameters=model.parameters or {} ) if model.end_time_option else None ) start_time_option = ( - RequestOption( - inject_into=RequestOptionType(model.start_time_option.inject_into.value), - field_name=model.start_time_option.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.start_time_option, config, parameters=model.parameters or {} ) if model.start_time_option else None @@ -1572,19 +1567,15 @@ def create_declarative_stream( cursor_model = model.incremental_sync end_time_option = ( - RequestOption( - inject_into=RequestOptionType(cursor_model.end_time_option.inject_into.value), - field_name=cursor_model.end_time_option.field_name, - parameters=cursor_model.parameters or {}, + self._create_component_from_model( + cursor_model.end_time_option, config, parameters=cursor_model.parameters or {} ) if cursor_model.end_time_option else None ) start_time_option = ( - RequestOption( - inject_into=RequestOptionType(cursor_model.start_time_option.inject_into.value), - field_name=cursor_model.start_time_option.field_name, - parameters=cursor_model.parameters or {}, + self._create_component_from_model( + cursor_model.start_time_option, config, parameters=cursor_model.parameters or {} ) if cursor_model.start_time_option else None @@ -2150,16 +2141,11 @@ def create_jwt_authenticator( additional_jwt_payload=model.additional_jwt_payload, ) - @staticmethod def create_list_partition_router( - model: ListPartitionRouterModel, config: Config, **kwargs: Any + self, model: ListPartitionRouterModel, config: Config, **kwargs: Any ) -> ListPartitionRouter: request_option = ( - RequestOption( - 
inject_into=RequestOptionType(model.request_option.inject_into.value), - field_name=model.request_option.field_name, - parameters=model.parameters or {}, - ) + self._create_component_from_model(model.request_option, config) if model.request_option else None ) @@ -2355,7 +2341,25 @@ def create_request_option( model: RequestOptionModel, config: Config, **kwargs: Any ) -> RequestOption: inject_into = RequestOptionType(model.inject_into.value) - return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={}) + field_path: Optional[List[Union[InterpolatedString, str]]] = ( + [ + InterpolatedString.create(segment, parameters=kwargs.get("parameters", {})) + for segment in model.field_path + ] + if model.field_path + else None + ) + field_name = ( + InterpolatedString.create(model.field_name, parameters=kwargs.get("parameters", {})) + if model.field_name + else None + ) + return RequestOption( + field_name=field_name, + field_path=field_path, + inject_into=inject_into, + parameters=kwargs.get("parameters", {}), + ) def create_record_selector( self, diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 29b700b04..6049cefe2 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -3,7 +3,7 @@ # from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -100,7 +100,9 @@ def _get_request_option( ): slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: - return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString + options: MutableMapping[str, Any] = {} + self.request_option.inject_into_request(options, slice_value, self.config) + return options else: return {} else: diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index c242215ea..6ccb055e8 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -4,7 +4,7 @@ import copy import logging from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union import dpath @@ -118,7 +118,7 @@ def get_request_body_json( def _get_request_option( self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] ) -> Mapping[str, Any]: - params = {} + params: MutableMapping[str, Any] = {} if stream_slice: for parent_config in self.parent_stream_configs: if ( @@ -128,13 +128,7 @@ def _get_request_option( key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string value = stream_slice.get(key) if value: - params.update( - { - parent_config.request_option.field_name.eval( # type: ignore [union-attr] - 
config=self.config - ): value - } - ) + parent_config.request_option.inject_into_request(params, value, self.config) return params def stream_slices(self) -> Iterable[StreamSlice]: diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 35d4b0f11..ad23f4d06 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -199,6 +199,9 @@ def _get_request_options( Raise a ValueError if there's a key collision Returned merged mapping otherwise """ + + is_body_json = requester_method.__name__ == "get_request_body_json" + return combine_mappings( [ requester_method( @@ -208,7 +211,8 @@ def _get_request_options( ), auth_options_method(), extra_options, - ] + ], + allow_same_value_merge=is_body_json, ) def _request_headers( diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index 59255c75b..6fb412cd9 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -187,7 +187,7 @@ def get_request_body_json( def _get_request_options( self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]] ) -> MutableMapping[str, Any]: - options = {} + options: MutableMapping[str, Any] = {} token = next_page_token.get("next_page_token") if next_page_token else None if ( @@ -196,15 +196,16 @@ def _get_request_options( and isinstance(self.page_token_option, RequestOption) and self.page_token_option.inject_into == option_type ): - options[self.page_token_option.field_name.eval(config=self.config)] = token # type: ignore # field_name is always cast to an interpolated string + self.page_token_option.inject_into_request(options, token, self.config) + if ( self.page_size_option and self.pagination_strategy.get_page_size() and self.page_size_option.inject_into == option_type ): - options[self.page_size_option.field_name.eval(config=self.config)] = ( # type: ignore [union-attr] - self.pagination_strategy.get_page_size() - ) # type: ignore # field_name is always cast to an interpolated string + page_size = self.pagination_strategy.get_page_size() + self.page_size_option.inject_into_request(options, page_size, self.config) + return options diff --git a/airbyte_cdk/sources/declarative/requesters/request_option.py b/airbyte_cdk/sources/declarative/requesters/request_option.py index d13d20566..e0946b53b 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_option.py +++ b/airbyte_cdk/sources/declarative/requesters/request_option.py @@ -4,9 +4,10 @@ from dataclasses import InitVar, dataclass from enum import Enum -from typing import Any, Mapping, Union +from typing import Any, List, Literal, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString +from airbyte_cdk.sources.types import Config class RequestOptionType(Enum): @@ -26,13 +27,91 @@ class RequestOption: Describes an option to set on a request Attributes: - field_name (str): Describes the name of the parameter to inject + field_name (str): Describes the name of the parameter to inject. Mutually exclusive with field_path. + field_path (list(str)): Describes the path to a nested field as a list of field names. 
+            Only valid for body_json injection type, and mutually exclusive with field_name.
         inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter
     """
 
-    field_name: Union[InterpolatedString, str]
     inject_into: RequestOptionType
     parameters: InitVar[Mapping[str, Any]]
+    field_name: Optional[Union[InterpolatedString, str]] = None
+    field_path: Optional[List[Union[InterpolatedString, str]]] = None
 
     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
-        self.field_name = InterpolatedString.create(self.field_name, parameters=parameters)
+        # Validate inputs. We should expect either field_name or field_path, but not both
+        if self.field_name is None and self.field_path is None:
+            raise ValueError("RequestOption requires either a field_name or field_path")
+
+        if self.field_name is not None and self.field_path is not None:
+            raise ValueError(
+                "Only one of field_name or field_path can be provided to RequestOption"
+            )
+
+        # Nested field injection is only supported for body JSON injection
+        if self.field_path is not None and self.inject_into != RequestOptionType.body_json:
+            raise ValueError(
+                "Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types."
+            )
+
+        # Convert field_name and field_path into InterpolatedString objects if they are strings
+        if self.field_name is not None:
+            self.field_name = InterpolatedString.create(self.field_name, parameters=parameters)
+        elif self.field_path is not None:
+            self.field_path = [
+                InterpolatedString.create(segment, parameters=parameters)
+                for segment in self.field_path
+            ]
+
+    @property
+    def _is_field_path(self) -> bool:
+        """Returns whether this option is a field path (i.e., a nested field)"""
+        return self.field_path is not None
+
+    def inject_into_request(
+        self,
+        target: MutableMapping[str, Any],
+        value: Any,
+        config: Config,
+    ) -> None:
+        """
+        Inject a request option value into a target request structure using either field_name or field_path.
+        For non-body-json injection, only top-level field names are supported.
+        For body-json injection, both field names and nested field paths are supported.
+
+        Args:
+            target: The request structure to inject the value into
+            value: The value to inject
+            config: The config object to use for interpolation
+        """
+        if self._is_field_path:
+            if self.inject_into != RequestOptionType.body_json:
+                raise ValueError(
+                    "Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types." 
+ ) + + assert self.field_path is not None # for type checker + current = target + # Convert path segments into strings, evaluating any interpolated segments + # Example: ["data", "{{ config[user_type] }}", "id"] -> ["data", "admin", "id"] + *path_parts, final_key = [ + str( + segment.eval(config=config) + if isinstance(segment, InterpolatedString) + else segment + ) + for segment in self.field_path + ] + + # Build a nested dictionary structure and set the final value at the deepest level + for part in path_parts: + current = current.setdefault(part, {}) + current[final_key] = value + else: + # For non-nested fields, evaluate the field name if it's an interpolated string + key = ( + self.field_name.eval(config=config) + if isinstance(self.field_name, InterpolatedString) + else self.field_name + ) + target[str(key)] = value diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py index 05e06db71..437ea7b7b 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py @@ -80,12 +80,13 @@ def _get_request_options( options: MutableMapping[str, Any] = {} if not stream_slice: return options + if self.start_time_option and self.start_time_option.inject_into == option_type: - options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string - self._partition_field_start.eval(self.config) - ) + start_time_value = stream_slice.get(self._partition_field_start.eval(self.config)) + self.start_time_option.inject_into_request(options, start_time_value, self.config) + if self.end_time_option and self.end_time_option.inject_into == option_type: - options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr] - self._partition_field_end.eval(self.config) - ) + end_time_value = stream_slice.get(self._partition_field_end.eval(self.config)) + self.end_time_option.inject_into_request(options, end_time_value, self.config) + return options diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index 45533ac4b..a5a8a71bc 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -128,6 +128,9 @@ def _get_request_options( Returned merged mapping otherwise """ # FIXME we should eventually remove the usage of stream_state as part of the interpolation + + is_body_json = paginator_method.__name__ == "get_request_body_json" + mappings = [ paginator_method( stream_state=stream_state, @@ -143,7 +146,7 @@ def _get_request_options( next_page_token=next_page_token, ) ) - return combine_mappings(mappings) + return combine_mappings(mappings, allow_same_value_merge=is_body_json) def _request_headers( self, diff --git a/airbyte_cdk/utils/mapping_helpers.py b/airbyte_cdk/utils/mapping_helpers.py index 469fb5e0a..c5682c288 100644 --- a/airbyte_cdk/utils/mapping_helpers.py +++ b/airbyte_cdk/utils/mapping_helpers.py @@ -3,43 +3,102 @@ # -from typing import Any, List, Mapping, Optional, Set, Union +import copy +from typing import Any, Dict, List, Mapping, Optional, Union + + +def _merge_mappings( + 
target: Dict[str, Any],
+    source: Mapping[str, Any],
+    path: Optional[List[str]] = None,
+    allow_same_value_merge: bool = False,
+) -> None:
+    """
+    Recursively merge two dictionaries, raising an error if there are any conflicts.
+    For body_json requests (allow_same_value_merge=True), a conflict occurs only when the same path has different values.
+    For other request types (allow_same_value_merge=False), any duplicate key is a conflict, regardless of value.
+
+    Args:
+        target: The dictionary to merge into
+        source: The dictionary to merge from
+        path: The current path in the nested structure (for error messages)
+        allow_same_value_merge: Whether to allow merging the same value into the same key. Defaults to False; should only be True for body_json injections.
+    """
+    path = path or []
+    for key, source_value in source.items():
+        current_path = path + [str(key)]
+
+        if key in target:
+            target_value = target[key]
+            if isinstance(target_value, dict) and isinstance(source_value, dict):
+                # Only body_json supports nested structures
+                if not allow_same_value_merge:
+                    raise ValueError(f"Duplicate keys found: {'.'.join(current_path)}")
+                # If both are dictionaries, recursively merge them
+                _merge_mappings(target_value, source_value, current_path, allow_same_value_merge)
+
+            elif not allow_same_value_merge or target_value != source_value:
+                # If same key has different values, that's a conflict
+                raise ValueError(f"Duplicate keys found: {'.'.join(current_path)}")
+        else:
+            # No conflict, just copy the value (using deepcopy for nested structures)
+            target[key] = copy.deepcopy(source_value)
+
 
 def combine_mappings(
     mappings: List[Optional[Union[Mapping[str, Any], str]]],
+    allow_same_value_merge: bool = False,
 ) -> Union[Mapping[str, Any], str]:
     """
-    Combine multiple mappings into a single mapping. If any of the mappings are a string, return
-    that string. Raise errors in the following cases:
-    * If there are duplicate keys across mappings
-    * If there are multiple string mappings
-    * If there are multiple mappings containing keys and one of them is a string
+    Combine multiple mappings into a single mapping.
+
+    For body_json requests (allow_same_value_merge=True):
+    - Supports nested structures (e.g., {"data": {"user": {"id": 1}}})
+    - Allows duplicate keys if their values match
+    - Raises error if same path has different values
+
+    For other request types (allow_same_value_merge=False):
+    - Only supports flat structures
+    - Any duplicate key raises an error, regardless of value
+
+    Args:
+        mappings: List of mappings to combine
+        allow_same_value_merge: Whether to allow duplicate keys with matching values.
+            Should only be True for body_json requests.
+
+    Returns:
+        A single mapping combining all inputs, or a string if there is exactly one
+        string mapping and no other non-empty mappings. 
+ + Raises: + ValueError: If there are: + - Multiple string mappings + - Both a string mapping and non-empty dictionary mappings + - Conflicting keys/paths based on allow_same_value_merge setting """ - all_keys: List[Set[str]] = [] - for part in mappings: - if part is None: - continue - keys = set(part.keys()) if not isinstance(part, str) else set() - all_keys.append(keys) - - string_options = sum(isinstance(mapping, str) for mapping in mappings) - # If more than one mapping is a string, raise a ValueError + if not mappings: + return {} + + # Count how many string options we have, ignoring None values + string_options = sum(isinstance(mapping, str) for mapping in mappings if mapping is not None) if string_options > 1: raise ValueError("Cannot combine multiple string options") - if string_options == 1 and sum(len(keys) for keys in all_keys) > 0: - raise ValueError("Cannot combine multiple options if one is a string") + # Filter out None values and empty mappings + non_empty_mappings = [ + m for m in mappings if m is not None and not (isinstance(m, Mapping) and not m) + ] - # If any mapping is a string, return it - for mapping in mappings: - if isinstance(mapping, str): - return mapping + # If there is only one string option and no other non-empty mappings, return it + if string_options == 1: + if len(non_empty_mappings) > 1: + raise ValueError("Cannot combine multiple options if one is a string") + return next(m for m in non_empty_mappings if isinstance(m, str)) - # If there are duplicate keys across mappings, raise a ValueError - intersection = set().union(*all_keys) - if len(intersection) < sum(len(keys) for keys in all_keys): - raise ValueError(f"Duplicate keys found: {intersection}") + # Start with an empty result and merge each mapping into it + result: Dict[str, Any] = {} + for mapping in non_empty_mappings: + if mapping and isinstance(mapping, Mapping): + _merge_mappings(result, mapping, allow_same_value_merge=allow_same_value_merge) - # Return the combined mappings - return {key: value for mapping in mappings if mapping for key, value in mapping.items()} # type: ignore # mapping can't be string here + return result diff --git a/unit_tests/sources/declarative/auth/test_token_auth.py b/unit_tests/sources/declarative/auth/test_token_auth.py index 2a23d4e19..4e90367c1 100644 --- a/unit_tests/sources/declarative/auth/test_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_token_auth.py @@ -248,3 +248,38 @@ def test_api_key_authenticator_inject( parameters=parameters, ) assert {expected_field_name: expected_field_value} == getattr(token_auth, validation_fn)() + + +@pytest.mark.parametrize( + "field_path, token, expected_result", + [ + ( + ["data", "auth", "token"], + "test-token", + {"data": {"auth": {"token": "test-token"}}}, + ), + ( + ["api", "{{ config.api_version }}", "auth", "token"], + "test-token", + {"api": {"v2": {"auth": {"token": "test-token"}}}}, + ), + ], + ids=["Basic nested structure", "Nested with config interpolation"], +) +def test_api_key_authenticator_nested_token_injection(field_path, token, expected_result): + """Test that the ApiKeyAuthenticator can properly inject tokens into nested structures when using body_json""" + config = {"api_version": "v2"} + parameters = {"auth_type": "bearer"} + + token_provider = InterpolatedStringTokenProvider( + config=config, api_token=token, parameters=parameters + ) + token_auth = ApiKeyAuthenticator( + request_option=RequestOption( + inject_into=RequestOptionType.body_json, field_path=field_path, parameters=parameters + 
), + token_provider=token_provider, + config=config, + parameters=parameters, + ) + assert token_auth.get_request_body_json() == expected_result diff --git a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py index 37ed7ebfe..3ddc5847f 100644 --- a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py @@ -782,13 +782,14 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state @pytest.mark.parametrize( - "test_name, inject_into, field_name, expected_req_params, expected_headers, expected_body_json, expected_body_data", + "test_name, inject_into, field_name, field_path, expected_req_params, expected_headers, expected_body_json, expected_body_data", [ - ("test_start_time_inject_into_none", None, None, {}, {}, {}, {}), + ("test_start_time_inject_into_none", None, None, None, {}, {}, {}, {}), ( "test_start_time_passed_by_req_param", RequestOptionType.request_parameter, "start_time", + None, { "start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000", @@ -801,6 +802,7 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state "test_start_time_inject_into_header", RequestOptionType.header, "start_time", + None, {}, { "start_time": "2021-01-01T00:00:00.000000+0000", @@ -810,9 +812,10 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state {}, ), ( - "test_start_time_inject_intoy_body_json", + "test_start_time_inject_into_body_json", RequestOptionType.body_json, "start_time", + None, {}, {}, { @@ -821,10 +824,30 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state }, {}, ), + ( + "test_nested_field_injection_into_body_json", + RequestOptionType.body_json, + None, + ["data", "queries", "time_range", "start"], + {}, + {}, + { + "data": { + "queries": { + "time_range": { + "start": "2021-01-01T00:00:00.000000+0000", + "end": "2021-01-04T00:00:00.000000+0000", + } + } + } + }, + {}, + ), ( "test_start_time_inject_into_body_data", RequestOptionType.body_data, "start_time", + None, {}, {}, {}, @@ -839,18 +862,26 @@ def test_request_option( test_name, inject_into, field_name, + field_path, expected_req_params, expected_headers, expected_body_json, expected_body_data, ): start_request_option = ( - RequestOption(inject_into=inject_into, parameters={}, field_name=field_name) + RequestOption( + inject_into=inject_into, parameters={}, field_name=field_name, field_path=field_path + ) if inject_into else None ) end_request_option = ( - RequestOption(inject_into=inject_into, parameters={}, field_name="endtime") + RequestOption( + inject_into=inject_into, + parameters={}, + field_name="endtime" if field_name else None, + field_path=["data", "queries", "time_range", "end"] if field_path else None, + ) if inject_into else None ) diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index e489a8526..43564a5c8 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -609,8 +609,8 @@ def test_list_based_stream_slicer_with_values_defined_in_config(): cursor_field: repository request_option: type: RequestOption - inject_into: header - field_name: repository + 
inject_into: body_json + field_path: ["repository", "id"] """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) @@ -626,8 +626,10 @@ def test_list_based_stream_slicer_with_values_defined_in_config(): assert isinstance(partition_router, ListPartitionRouter) assert partition_router.values == ["airbyte", "airbyte-cloud"] - assert partition_router.request_option.inject_into == RequestOptionType.header - assert partition_router.request_option.field_name.eval(config=input_config) == "repository" + assert partition_router.request_option.inject_into == RequestOptionType.body_json + for field in partition_router.request_option.field_path: + assert isinstance(field, InterpolatedString) + assert len(partition_router.request_option.field_path) == 2 def test_create_substream_partition_router(): @@ -730,7 +732,7 @@ def test_datetime_based_cursor(): end_time_option: type: RequestOption inject_into: body_json - field_name: "before_{{ parameters['cursor_field'] }}" + field_path: ["before_{{ parameters['cursor_field'] }}"] partition_field_start: star partition_field_end: en """ @@ -759,7 +761,9 @@ def test_datetime_based_cursor(): == "since_updated_at" ) assert stream_slicer.end_time_option.inject_into == RequestOptionType.body_json - assert stream_slicer.end_time_option.field_name.eval({}) == "before_created_at" + assert [field.eval({}) for field in stream_slicer.end_time_option.field_path] == [ + "before_created_at" + ] assert stream_slicer._partition_field_start.eval({}) == "star" assert stream_slicer._partition_field_end.eval({}) == "en" @@ -920,8 +924,8 @@ def test_resumable_full_refresh_stream(): type: DefaultPaginator page_size_option: type: RequestOption - inject_into: request_parameter - field_name: page_size + inject_into: body_json + field_path: ["variables", "page_size"] page_token_option: type: RequestPath pagination_strategy: @@ -1019,11 +1023,10 @@ def test_resumable_full_refresh_stream(): assert isinstance(stream.retriever.paginator, DefaultPaginator) assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) - assert stream.retriever.paginator.page_size_option.field_name.eval(input_config) == "page_size" - assert ( - stream.retriever.paginator.page_size_option.inject_into - == RequestOptionType.request_parameter - ) + for string in stream.retriever.paginator.page_size_option.field_path: + assert isinstance(string, InterpolatedString) + assert len(stream.retriever.paginator.page_size_option.field_path) == 2 + assert stream.retriever.paginator.page_size_option.inject_into == RequestOptionType.body_json assert isinstance(stream.retriever.paginator.page_token_option, RequestPath) assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com/v3/" assert stream.retriever.paginator.url_base.default == "https://api.sendgrid.com/v3/" @@ -2525,7 +2528,6 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e assert isinstance(stream, DeclarativeStream) assert isinstance(stream.retriever, SimpleRetriever) - print(stream.retriever.stream_slicer) assert isinstance(stream.retriever.stream_slicer, expected_type) if incremental and partition_router: diff --git a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py index cbe185a37..57b6d9d34 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +++ 
b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py @@ -437,7 +437,9 @@ def test_paginator_with_page_option_no_page_size(): DefaultPaginator( page_size_option=MagicMock(), page_token_option=RequestOption( - "limit", RequestOptionType.request_parameter, parameters={} + field_name="limit", + inject_into=RequestOptionType.request_parameter, + parameters={}, ), pagination_strategy=pagination_strategy, config=MagicMock(), diff --git a/unit_tests/sources/declarative/requesters/request_options/test_request_options.py b/unit_tests/sources/declarative/requesters/request_options/test_request_options.py new file mode 100644 index 000000000..115ce688d --- /dev/null +++ b/unit_tests/sources/declarative/requesters/request_options/test_request_options.py @@ -0,0 +1,199 @@ +from typing import Any, Dict, List, Optional, Type + +import pytest + +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) + + +@pytest.mark.parametrize( + "field_name, field_path, inject_into, error_type, error_message", + [ + ( + None, + None, + RequestOptionType.body_json, + ValueError, + "RequestOption requires either a field_name or field_path", + ), + ( + "field", + ["data", "field"], + RequestOptionType.body_json, + ValueError, + "Only one of field_name or field_path can be provided", + ), + ( + None, + ["data", "field"], + RequestOptionType.header, + ValueError, + "Nested field injection is only supported for body JSON injection.", + ), + ], +) +def test_request_option_validation( + field_name: Optional[str], + field_path: Any, + inject_into: RequestOptionType, + error_type: Type[Exception], + error_message: str, +): + """Test various validation cases for RequestOption""" + with pytest.raises(error_type, match=error_message): + RequestOption( + field_name=field_name, field_path=field_path, inject_into=inject_into, parameters={} + ) + + +@pytest.mark.parametrize( + "request_option_args, value, expected_result", + [ + # Basic field_name test + ( + { + "field_name": "test_{{ config['base_field'] }}", + "inject_into": RequestOptionType.body_json, + }, + "test_value", + {"test_value": "test_value"}, + ), + # Basic field_path test + ( + { + "field_path": ["data", "nested_{{ config['base_field'] }}", "field"], + "inject_into": RequestOptionType.body_json, + }, + "test_value", + {"data": {"nested_value": {"field": "test_value"}}}, + ), + # Deep nesting test + ( + { + "field_path": ["level1", "level2", "level3", "level4", "field"], + "inject_into": RequestOptionType.body_json, + }, + "deep_value", + {"level1": {"level2": {"level3": {"level4": {"field": "deep_value"}}}}}, + ), + ], +) +def test_inject_into_request_cases( + request_option_args: Dict[str, Any], value: Any, expected_result: Dict[str, Any] +): + """Test various injection cases""" + config = {"base_field": "value"} + target: Dict[str, Any] = {} + + request_option = RequestOption(**request_option_args, parameters={}) + request_option.inject_into_request(target, value, config) + assert target == expected_result + + +@pytest.mark.parametrize( + "config, parameters, field_path, expected_structure", + [ + ( + {"nested": "user"}, + {"type": "profile"}, + ["data", "{{ config['nested'] }}", "{{ parameters['type'] }}"], + {"data": {"user": {"profile": "test_value"}}}, + ), + ( + {"user_type": "admin", "section": "profile"}, + {"id": "12345"}, + [ + "data", + "{{ config['user_type'] }}", + "{{ parameters['id'] }}", + "{{ config['section'] }}", + "details", + ], + {"data": {"admin": {"12345": 
{"profile": {"details": "test_value"}}}}}, + ), + ], +) +def test_interpolation_cases( + config: Dict[str, Any], + parameters: Dict[str, Any], + field_path: List[str], + expected_structure: Dict[str, Any], +): + """Test various interpolation scenarios""" + request_option = RequestOption( + field_path=field_path, inject_into=RequestOptionType.body_json, parameters=parameters + ) + target: Dict[str, Any] = {} + request_option.inject_into_request(target, "test_value", config) + assert target == expected_structure + + +@pytest.mark.parametrize( + "value, expected_type", + [ + (42, int), + (3.14, float), + (True, bool), + (["a", "b", "c"], list), + ({"key": "value"}, dict), + (None, type(None)), + ], +) +def test_value_type_handling(value: Any, expected_type: Type): + """Test handling of different value types""" + config = {} + target: Dict[str, Any] = {} + request_option = RequestOption( + field_path=["data", "test"], inject_into=RequestOptionType.body_json, parameters={} + ) + request_option.inject_into_request(target, value, config) + assert isinstance(target["data"]["test"], expected_type) + assert target["data"]["test"] == value + + +@pytest.mark.parametrize( + "field_name, field_path, inject_into, expected__is_field_path", + [ + ("field", None, RequestOptionType.body_json, False), + (None, ["data", "field"], RequestOptionType.body_json, True), + ], +) +def test__is_field_path( + field_name: Optional[str], + field_path: Optional[List[str]], + inject_into: RequestOptionType, + expected__is_field_path: bool, +): + """Test the _is_field_path property""" + request_option = RequestOption( + field_name=field_name, field_path=field_path, inject_into=inject_into, parameters={} + ) + assert request_option._is_field_path == expected__is_field_path + + +def test_multiple_injections(): + """Test injecting multiple values into the same target dict""" + config = {"base": "test"} + target = {"existing": "value"} + + # First injection with field_name + option1 = RequestOption( + field_name="field1", inject_into=RequestOptionType.body_json, parameters={} + ) + option1.inject_into_request(target, "value1", config) + + # Second injection with nested path + option2 = RequestOption( + field_path=["data", "nested", "field2"], + inject_into=RequestOptionType.body_json, + parameters={}, + ) + option2.inject_into_request(target, "value2", config) + + assert target == { + "existing": "value", + "field1": "value1", + "data": {"nested": {"field2": "value2"}}, + } diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index 8e63aa21e..f02ec206b 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -279,8 +279,8 @@ def test_basic_send_request(): None, '{"field": "value", "field2": "value", "authfield": "val"}', ), - (None, {"field": "value"}, None, {"field": "value"}, None, None, ValueError, None), - (None, {"field": "value"}, None, None, None, {"field": "value"}, ValueError, None), + (None, {"field": "value"}, None, {"field": "value"}, None, None, None, "field=value"), + (None, {"field": "value"}, None, None, None, {"field": "value"}, None, "field=value"), # raise on mixed data and json params ( {"field": "value"}, diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index b33febcaf..fe03c6ad4 100644 --- 
a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -58,6 +58,11 @@ def test_simple_retriever_full(mock_http_stream): request_params = {"param": "value"} requester.get_request_params.return_value = request_params + requester.get_request_params.__name__ = "get_request_params" + requester.get_request_headers.__name__ = "get_request_headers" + requester.get_request_body_data.__name__ = "get_request_body_data" + requester.get_request_body_json.__name__ = "get_request_body_json" + paginator = MagicMock() paginator.get_initial_token.return_value = None next_page_token = {"cursor": "cursor_value"} @@ -65,6 +70,11 @@ def test_simple_retriever_full(mock_http_stream): paginator.next_page_token.return_value = next_page_token paginator.get_request_headers.return_value = {} + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_headers.__name__ = "get_request_headers" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + record_selector = MagicMock() record_selector.select_records.return_value = records @@ -442,11 +452,19 @@ def test_get_request_options_from_pagination( paginator.get_request_body_data.return_value = paginator_mapping paginator.get_request_body_json.return_value = paginator_mapping + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + request_options_provider = MagicMock() request_options_provider.get_request_params.return_value = request_options_provider_mapping request_options_provider.get_request_body_data.return_value = request_options_provider_mapping request_options_provider.get_request_body_json.return_value = request_options_provider_mapping + request_options_provider.get_request_params.__name__ = "get_request_params" + request_options_provider.get_request_body_data.__name__ = "get_request_body_data" + request_options_provider.get_request_body_json.__name__ = "get_request_body_json" + record_selector = MagicMock() retriever = SimpleRetriever( name="stream_name", @@ -489,10 +507,12 @@ def test_get_request_headers(test_name, paginator_mapping, expected_mapping): # This test is separate from the other request options because request headers must be strings paginator = MagicMock() paginator.get_request_headers.return_value = paginator_mapping + paginator.get_request_headers.__name__ = "get_request_headers" requester = MagicMock(use_cache=False) request_option_provider = MagicMock() request_option_provider.get_request_headers.return_value = {"key": "value"} + request_option_provider.get_request_headers.__name__ = "get_request_headers" record_selector = MagicMock() retriever = SimpleRetriever( @@ -565,10 +585,12 @@ def test_ignore_request_option_provider_parameters_on_paginated_requests( # This test is separate from the other request options because request headers must be strings paginator = MagicMock() paginator.get_request_headers.return_value = paginator_mapping + paginator.get_request_headers.__name__ = "get_request_headers" requester = MagicMock(use_cache=False) request_option_provider = MagicMock() request_option_provider.get_request_headers.return_value = {"key_from_slicer": "value"} + request_option_provider.get_request_headers.__name__ = "get_request_headers" record_selector = MagicMock() retriever = SimpleRetriever( @@ 
-612,6 +634,7 @@ def test_request_body_data( ): paginator = MagicMock() paginator.get_request_body_data.return_value = paginator_body_data + paginator.get_request_body_data.__name__ = "get_request_body_data" requester = MagicMock(use_cache=False) request_option_provider = MagicMock() @@ -825,11 +848,25 @@ def test_emit_log_request_response_messages(mocker): "airbyte_cdk.sources.declarative.retrievers.simple_retriever.format_http_message" ) requester = MagicMock() + + # Add __name__ to mock methods + requester.get_request_params.__name__ = "get_request_params" + requester.get_request_headers.__name__ = "get_request_headers" + requester.get_request_body_data.__name__ = "get_request_body_data" + requester.get_request_body_json.__name__ = "get_request_body_json" + + # The paginator mock also needs __name__ attributes + paginator = MagicMock() + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_headers.__name__ = "get_request_headers" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + retriever = SimpleRetrieverTestReadDecorator( name="stream_name", primary_key=primary_key, requester=requester, - paginator=MagicMock(), + paginator=paginator, record_selector=record_selector, stream_slicer=SinglePartitionRouter(parameters={}), parameters={}, diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index e4eed4735..38d6874c0 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -1030,8 +1030,8 @@ def test_manifest_without_at_least_one_stream(self): "page_size": 10, "page_size_option": { "type": "RequestOption", - "inject_into": "request_parameter", - "field_name": "page_size", + "inject_into": "request_body", + "field_path": ["variables", "page_size"], }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { diff --git a/unit_tests/utils/test_mapping_helpers.py b/unit_tests/utils/test_mapping_helpers.py index 272ce9b7a..124bf4565 100644 --- a/unit_tests/utils/test_mapping_helpers.py +++ b/unit_tests/utils/test_mapping_helpers.py @@ -1,55 +1,115 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
-# - import pytest from airbyte_cdk.utils.mapping_helpers import combine_mappings -def test_basic_merge(): - mappings = [{"a": 1}, {"b": 2}, {"c": 3}, {}] - result = combine_mappings(mappings) - assert result == {"a": 1, "b": 2, "c": 3} - - -def test_combine_with_string(): - mappings = [{"a": 1}, "option"] - with pytest.raises(ValueError, match="Cannot combine multiple options if one is a string"): - combine_mappings(mappings) - - -def test_overlapping_keys(): - mappings = [{"a": 1, "b": 2}, {"b": 3}] - with pytest.raises(ValueError, match="Duplicate keys found"): - combine_mappings(mappings) - - -def test_multiple_strings(): - mappings = ["option1", "option2"] - with pytest.raises(ValueError, match="Cannot combine multiple string options"): - combine_mappings(mappings) - - -def test_handle_none_values(): - mappings = [{"a": 1}, None, {"b": 2}] - result = combine_mappings(mappings) - assert result == {"a": 1, "b": 2} - - -def test_empty_mappings(): - mappings = [] - result = combine_mappings(mappings) - assert result == {} - - -def test_single_mapping(): - mappings = [{"a": 1}] - result = combine_mappings(mappings) - assert result == {"a": 1} - - -def test_combine_with_string_and_empty_mappings(): - mappings = ["option", {}] - result = combine_mappings(mappings) - assert result == "option" +@pytest.mark.parametrize( + "test_name, mappings, expected_result", + [ + ("empty_mappings", [], {}), + ("single_mapping", [{"a": 1}], {"a": 1}), + ("handle_none_values", [{"a": 1}, None, {"b": 2}], {"a": 1, "b": 2}), + ], +) +def test_basic_functionality(test_name, mappings, expected_result): + """Test basic mapping operations that work the same regardless of request type""" + assert combine_mappings(mappings) == expected_result + + +@pytest.mark.parametrize( + "test_name, mappings, expected_result, expected_error", + [ + ( + "combine_with_string", + [{"a": 1}, "option"], + None, + "Cannot combine multiple options if one is a string", + ), + ( + "multiple_strings", + ["option1", "option2"], + None, + "Cannot combine multiple string options", + ), + ("string_with_empty_mapping", ["option", {}], "option", None), + ], +) +def test_string_handling(test_name, mappings, expected_result, expected_error): + """Test string handling behavior which is independent of request type""" + if expected_error: + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings) + else: + assert combine_mappings(mappings) == expected_result + + +@pytest.mark.parametrize( + "test_name, mappings, expected_error", + [ + ("duplicate_keys_same_value", [{"a": 1}, {"a": 1}], "Duplicate keys found"), + ("duplicate_keys_different_value", [{"a": 1}, {"a": 2}], "Duplicate keys found"), + ( + "nested_structure_not_allowed", + [{"a": {"b": 1}}, {"a": {"c": 2}}], + "Duplicate keys found", + ), + ("any_nesting_not_allowed", [{"a": {"b": 1}}, {"a": {"d": 2}}], "Duplicate keys found"), + ], +) +def test_non_body_json_requests(test_name, mappings, expected_error): + """Test strict validation for non-body-json requests (headers, params, body_data)""" + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings, allow_same_value_merge=False) + + +@pytest.mark.parametrize( + "test_name, mappings, expected_result, expected_error", + [ + ( + "simple_nested_merge", + [{"a": {"b": 1}}, {"c": {"d": 2}}], + {"a": {"b": 1}, "c": {"d": 2}}, + None, + ), + ( + "deep_nested_merge", + [{"a": {"b": {"c": 1}}}, {"d": {"e": {"f": 2}}}], + {"a": {"b": {"c": 1}}, "d": {"e": {"f": 2}}}, + None, + ), + ( + 
"nested_merge_same_level", + [ + {"data": {"user": {"id": 1}, "status": "active"}}, + {"data": {"user": {"name": "test"}, "type": "admin"}}, + ], + { + "data": { + "user": {"id": 1, "name": "test"}, + "status": "active", + "type": "admin", + }, + }, + None, + ), + ( + "nested_conflict", + [{"a": {"b": 1}}, {"a": {"b": 2}}], + None, + "Duplicate keys found", + ), + ( + "type_conflict", + [{"a": 1}, {"a": {"b": 2}}], + None, + "Duplicate keys found", + ), + ], +) +def test_body_json_requests(test_name, mappings, expected_result, expected_error): + """Test nested structure support for body_json requests""" + if expected_error: + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings, allow_same_value_merge=True) + else: + assert combine_mappings(mappings, allow_same_value_merge=True) == expected_result From ca68c5c289ae25d592d47efc2351d4a0aa729857 Mon Sep 17 00:00:00 2001 From: Brian Lai <51336873+brianjlai@users.noreply.github.com> Date: Tue, 4 Feb 2025 13:40:10 -0800 Subject: [PATCH 11/12] fix(concurrent cdk): Properly call set_initial_state() on the cursor that is initialized on the ClientSideIncrementalRecordFilterDecorator (#310) --- .../concurrent_declarative_source.py | 15 ++++++- .../test_concurrent_declarative_source.py | 41 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 92f4bdc4b..d4ecc0084 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -475,10 +475,21 @@ def _get_retriever( # Also a temporary hack. In the legacy Stream implementation, as part of the read, # set_initial_state() is called to instantiate incoming state on the cursor. Although we no # longer rely on the legacy low-code cursor for concurrent checkpointing, low-code components - # like StopConditionPaginationStrategyDecorator and ClientSideIncrementalRecordFilterDecorator - # still rely on a DatetimeBasedCursor that is properly initialized with state. + # like StopConditionPaginationStrategyDecorator still rely on a DatetimeBasedCursor that is + # properly initialized with state. 
From ca68c5c289ae25d592d47efc2351d4a0aa729857 Mon Sep 17 00:00:00 2001
From: Brian Lai <51336873+brianjlai@users.noreply.github.com>
Date: Tue, 4 Feb 2025 13:40:10 -0800
Subject: [PATCH 11/12] fix(concurrent cdk): Properly call set_initial_state()
 on the cursor that is initialized on the
 ClientSideIncrementalRecordFilterDecorator (#310)

---
 .../concurrent_declarative_source.py          | 15 ++++++-
 .../test_concurrent_declarative_source.py     | 41 +++++++++++++++++++
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 92f4bdc4b..d4ecc0084 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -475,10 +475,21 @@ def _get_retriever(
         # Also a temporary hack. In the legacy Stream implementation, as part of the read,
         # set_initial_state() is called to instantiate incoming state on the cursor. Although we no
         # longer rely on the legacy low-code cursor for concurrent checkpointing, low-code components
-        # like StopConditionPaginationStrategyDecorator and ClientSideIncrementalRecordFilterDecorator
-        # still rely on a DatetimeBasedCursor that is properly initialized with state.
+        # like StopConditionPaginationStrategyDecorator still rely on a DatetimeBasedCursor that is
+        # properly initialized with state.
         if retriever.cursor:
             retriever.cursor.set_initial_state(stream_state=stream_state)
+
+        # Similar to above, the ClientSideIncrementalRecordFilterDecorator cursor is a separate instance
+        # from the one initialized on the SimpleRetriever, so it must also have state initialized
+        # for semi-incremental streams using is_client_side_incremental to filter properly
+        if isinstance(retriever.record_selector, RecordSelector) and isinstance(
+            retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
+        ):
+            retriever.record_selector.record_filter._cursor.set_initial_state(
+                stream_state=stream_state
+            )  # type: ignore  # After non-concurrent cursors are deprecated we can remove these cursor workarounds
+
         # We zero it out here, but since this is a cursor reference, the state is still properly
         # instantiated for the other components that reference it
         retriever.cursor = None
diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py
index 892876850..1877e11bb 100644
--- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py
+++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py
@@ -32,6 +32,9 @@
     ConcurrentDeclarativeSource,
 )
 from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
+from airbyte_cdk.sources.declarative.extractors.record_filter import (
+    ClientSideIncrementalRecordFilterDecorator,
+)
 from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
 from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
     StreamSlicerPartitionGenerator,
@@ -1647,6 +1650,44 @@ def test_async_incremental_stream_uses_concurrent_cursor_with_state():
     assert async_job_partition_router.stream_slicer._concurrent_state == expected_state


+def test_stream_using_is_client_side_incremental_has_cursor_state():
+    expected_cursor_value = "2024-07-01"
+    state = [
+        AirbyteStateMessage(
+            type=AirbyteStateType.STREAM,
+            stream=AirbyteStreamState(
+                stream_descriptor=StreamDescriptor(name="locations", namespace=None),
+                stream_state=AirbyteStateBlob(updated_at=expected_cursor_value),
+            ),
+        )
+    ]
+
+    manifest_with_stream_state_interpolation = copy.deepcopy(_MANIFEST)
+
+    # Enable semi-incremental on the locations stream
+    manifest_with_stream_state_interpolation["definitions"]["locations_stream"]["incremental_sync"][
+        "is_client_side_incremental"
+    ] = True
+
+    source = ConcurrentDeclarativeSource(
+        source_config=manifest_with_stream_state_interpolation,
+        config=_CONFIG,
+        catalog=_CATALOG,
+        state=state,
+    )
+    concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)
+
+    locations_stream = concurrent_streams[2]
+    assert isinstance(locations_stream, DefaultStream)
+
+    simple_retriever = locations_stream._stream_partition_generator._partition_factory._retriever
+    record_filter = simple_retriever.record_selector.record_filter
+    assert isinstance(record_filter, ClientSideIncrementalRecordFilterDecorator)
+    client_side_incremental_cursor_state = record_filter._cursor._cursor
+
+    assert client_side_incremental_cursor_state == expected_cursor_value
+
+
 def create_wrapped_stream(stream: DeclarativeStream) -> Stream:
     slice_to_records_mapping = get_mocked_read_records_output(stream_name=stream.name)
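The gist of the fix above: two distinct cursor instances live side by side (one on the retriever, one inside the record filter decorator), and only the first was being seeded with incoming state. A minimal, self-contained sketch of the pitfall, using toy stand-in classes rather than the CDK's real constructors:

```python
class ToyCursor:
    """Stand-in for DatetimeBasedCursor; only tracks a state dict."""

    def __init__(self) -> None:
        self.state: dict = {}

    def set_initial_state(self, stream_state: dict) -> None:
        self.state = dict(stream_state)


class ToyRetriever:
    def __init__(self) -> None:
        self.cursor = ToyCursor()


class ToyRecordFilter:
    def __init__(self) -> None:
        self._cursor = ToyCursor()  # a *separate* instance from the retriever's


stream_state = {"updated_at": "2024-07-01"}
retriever, record_filter = ToyRetriever(), ToyRecordFilter()

retriever.cursor.set_initial_state(stream_state=stream_state)
assert record_filter._cursor.state == {}  # still empty: filtering would be wrong

record_filter._cursor.set_initial_state(stream_state=stream_state)  # the added call
assert record_filter._cursor.state == stream_state
```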
From e38f914bc49912ae99a7bb1d4a0b5f50393fee7f Mon Sep 17 00:00:00 2001
From: "Aaron (\"AJ\") Steers"
Date: Thu, 6 Feb 2025 08:21:44 -0800
Subject: [PATCH 12/12] Chore: add new test using `source-pokeapi` and custom
 `components.py` (#317)

---
 .../source_pokeapi_w_components_py/README.md  |   3 +
 .../components.py                             |  20 +
 .../components_failing.py                     |  24 +
 .../manifest.yaml                             | 980 ++++++++++++++++++
 .../valid_config.yaml                         |   1 +
 .../source_the_guardian_api/.gitignore        |   1 -
 .../source_the_guardian_api/README.md         |   9 -
 .../source_the_guardian_api/components.py     |  61 --
 .../components_failing.py                     |  54 -
 .../source_the_guardian_api/manifest.yaml     | 376 -------
 .../source_the_guardian_api/valid_config.yaml |   1 -
 ..._source_declarative_w_custom_components.py |  49 +-
 12 files changed, 1046 insertions(+), 533 deletions(-)
 create mode 100644 unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/README.md
 create mode 100644 unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components.py
 create mode 100644 unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components_failing.py
 create mode 100644 unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/manifest.yaml
 create mode 100644 unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/valid_config.yaml
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/.gitignore
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/README.md
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components.py
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components_failing.py
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/manifest.yaml
 delete mode 100644 unit_tests/source_declarative_manifest/resources/source_the_guardian_api/valid_config.yaml

diff --git a/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/README.md b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/README.md
new file mode 100644
index 000000000..78505726c
--- /dev/null
+++ b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/README.md
@@ -0,0 +1,3 @@
+# PokeAPI with Custom `components.py` API Tests
+
+This test connector is a modified version of `source-pokeapi`. It has been modified to use a custom `components.py` so that we have a test case that completes quickly and _does not_ require any credentials.
diff --git a/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components.py b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components.py
new file mode 100644
index 000000000..5e7e16f71
--- /dev/null
+++ b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components.py
@@ -0,0 +1,20 @@
+"""A sample implementation of custom components that does nothing but will cause syncs to fail if missing."""
+
+from typing import Any, Mapping
+
+import requests
+
+from airbyte_cdk.sources.declarative.extractors import DpathExtractor
+
+
+class IntentionalException(Exception):
+    """This exception is raised intentionally in order to test error handling."""
+
+
+class MyCustomExtractor(DpathExtractor):
+    """Dummy class that directly subclasses DpathExtractor.
+
+    Used to prove that SDM can find the custom class by name.
+    """
+
+    pass
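The no-op `MyCustomExtractor` above only proves that SDM can import and resolve the class. For contrast, a behavior-changing override might look like the sketch below, reusing the same `extract_records` hook that `components_failing.py` overrides next; the field name and transform are purely illustrative, and it assumes the base extractor yields dict-like records:

```python
from collections.abc import Iterable, MutableMapping
from typing import Any

import requests

from airbyte_cdk.sources.declarative.extractors import DpathExtractor


class NormalizingExtractor(DpathExtractor):
    """Hypothetical variant: delegate to DpathExtractor, then normalize one field."""

    def extract_records(
        self,
        response: requests.Response,
    ) -> Iterable[MutableMapping[Any, Any]]:
        for record in super().extract_records(response):
            # Illustrative post-processing; "name" is an assumed field.
            if isinstance(record.get("name"), str):
                record["name"] = record["name"].strip().lower()
            yield record
```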
diff --git a/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components_failing.py b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components_failing.py
new file mode 100644
index 000000000..5c05881e7
--- /dev/null
+++ b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/components_failing.py
@@ -0,0 +1,24 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+"""A sample implementation of custom components that does nothing but will cause syncs to fail if missing."""
+
+from collections.abc import Iterable, MutableMapping
+from dataclasses import InitVar, dataclass
+from typing import Any, Mapping, Optional, Union
+
+import requests
+
+from airbyte_cdk.sources.declarative.extractors import DpathExtractor
+
+
+class IntentionalException(Exception):
+    """This exception is raised intentionally in order to test error handling."""
+
+
+class MyCustomExtractor(DpathExtractor):
+    def extract_records(
+        self,
+        response: requests.Response,
+    ) -> Iterable[MutableMapping[Any, Any]]:
+        raise IntentionalException
diff --git a/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/manifest.yaml b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/manifest.yaml
new file mode 100644
index 000000000..af19485fa
--- /dev/null
+++ b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/manifest.yaml
@@ -0,0 +1,980 @@
+version: 6.30.0
+
+type: DeclarativeSource
+
+check:
+  type: CheckStream
+  stream_names:
+    - pokemon
+
+definitions:
+  streams:
+    pokemon:
+      type: DeclarativeStream
+      name: pokemon
+      retriever:
+        type: SimpleRetriever
+        requester:
+          $ref: "#/definitions/base_requester"
+          path: /{{config['pokemon_name']}}
+          http_method: GET
+        record_selector:
+          type: RecordSelector
+          extractor:
+            # Simple wrapper around `DpathExtractor`
+            type: CustomRecordExtractor
+            class_name: components.MyCustomExtractor
+            field_path: []
+      primary_key:
+        - id
+      schema_loader:
+        type: InlineSchemaLoader
+        # type: CustomSchemaLoader
+        schema:
+          $ref: "#/schemas/pokemon"
+        # class_name: components.MyCustomInlineSchemaLoader
+  base_requester:
+    type: HttpRequester
+    url_base: https://pokeapi.co/api/v2/pokemon
+
+streams:
+  - $ref: "#/definitions/streams/pokemon"
+
+spec:
+  type: Spec
+  connection_specification:
+    type: object
+    $schema: http://json-schema.org/draft-07/schema#
+    required:
+      - pokemon_name
+    properties:
+      pokemon_name:
+        type: string
+        description: Pokemon requested from the API.
+ enum: + - bulbasaur + - ivysaur + - venusaur + - charmander + - charmeleon + - charizard + - squirtle + - wartortle + - blastoise + - caterpie + - metapod + - butterfree + - weedle + - kakuna + - beedrill + - pidgey + - pidgeotto + - pidgeot + - rattata + - raticate + - spearow + - fearow + - ekans + - arbok + - pikachu + - raichu + - sandshrew + - sandslash + - nidoranf + - nidorina + - nidoqueen + - nidoranm + - nidorino + - nidoking + - clefairy + - clefable + - vulpix + - ninetales + - jigglypuff + - wigglytuff + - zubat + - golbat + - oddish + - gloom + - vileplume + - paras + - parasect + - venonat + - venomoth + - diglett + - dugtrio + - meowth + - persian + - psyduck + - golduck + - mankey + - primeape + - growlithe + - arcanine + - poliwag + - poliwhirl + - poliwrath + - abra + - kadabra + - alakazam + - machop + - machoke + - machamp + - bellsprout + - weepinbell + - victreebel + - tentacool + - tentacruel + - geodude + - graveler + - golem + - ponyta + - rapidash + - slowpoke + - slowbro + - magnemite + - magneton + - farfetchd + - doduo + - dodrio + - seel + - dewgong + - grimer + - muk + - shellder + - cloyster + - gastly + - haunter + - gengar + - onix + - drowzee + - hypno + - krabby + - kingler + - voltorb + - electrode + - exeggcute + - exeggutor + - cubone + - marowak + - hitmonlee + - hitmonchan + - lickitung + - koffing + - weezing + - rhyhorn + - rhydon + - chansey + - tangela + - kangaskhan + - horsea + - seadra + - goldeen + - seaking + - staryu + - starmie + - mrmime + - scyther + - jynx + - electabuzz + - magmar + - pinsir + - tauros + - magikarp + - gyarados + - lapras + - ditto + - eevee + - vaporeon + - jolteon + - flareon + - porygon + - omanyte + - omastar + - kabuto + - kabutops + - aerodactyl + - snorlax + - articuno + - zapdos + - moltres + - dratini + - dragonair + - dragonite + - mewtwo + - mew + - chikorita + - bayleef + - meganium + - cyndaquil + - quilava + - typhlosion + - totodile + - croconaw + - feraligatr + - sentret + - furret + - hoothoot + - noctowl + - ledyba + - ledian + - spinarak + - ariados + - crobat + - chinchou + - lanturn + - pichu + - cleffa + - igglybuff + - togepi + - togetic + - natu + - xatu + - mareep + - flaaffy + - ampharos + - bellossom + - marill + - azumarill + - sudowoodo + - politoed + - hoppip + - skiploom + - jumpluff + - aipom + - sunkern + - sunflora + - yanma + - wooper + - quagsire + - espeon + - umbreon + - murkrow + - slowking + - misdreavus + - unown + - wobbuffet + - girafarig + - pineco + - forretress + - dunsparce + - gligar + - steelix + - snubbull + - granbull + - qwilfish + - scizor + - shuckle + - heracross + - sneasel + - teddiursa + - ursaring + - slugma + - magcargo + - swinub + - piloswine + - corsola + - remoraid + - octillery + - delibird + - mantine + - skarmory + - houndour + - houndoom + - kingdra + - phanpy + - donphan + - porygon2 + - stantler + - smeargle + - tyrogue + - hitmontop + - smoochum + - elekid + - magby + - miltank + - blissey + - raikou + - entei + - suicune + - larvitar + - pupitar + - tyranitar + - lugia + - ho-oh + - celebi + - treecko + - grovyle + - sceptile + - torchic + - combusken + - blaziken + - mudkip + - marshtomp + - swampert + - poochyena + - mightyena + - zigzagoon + - linoone + - wurmple + - silcoon + - beautifly + - cascoon + - dustox + - lotad + - lombre + - ludicolo + - seedot + - nuzleaf + - shiftry + - taillow + - swellow + - wingull + - pelipper + - ralts + - kirlia + - gardevoir + - surskit + - masquerain + - shroomish + - breloom + - slakoth + - vigoroth + - 
slaking + - nincada + - ninjask + - shedinja + - whismur + - loudred + - exploud + - makuhita + - hariyama + - azurill + - nosepass + - skitty + - delcatty + - sableye + - mawile + - aron + - lairon + - aggron + - meditite + - medicham + - electrike + - manectric + - plusle + - minun + - volbeat + - illumise + - roselia + - gulpin + - swalot + - carvanha + - sharpedo + - wailmer + - wailord + - numel + - camerupt + - torkoal + - spoink + - grumpig + - spinda + - trapinch + - vibrava + - flygon + - cacnea + - cacturne + - swablu + - altaria + - zangoose + - seviper + - lunatone + - solrock + - barboach + - whiscash + - corphish + - crawdaunt + - baltoy + - claydol + - lileep + - cradily + - anorith + - armaldo + - feebas + - milotic + - castform + - kecleon + - shuppet + - banette + - duskull + - dusclops + - tropius + - chimecho + - absol + - wynaut + - snorunt + - glalie + - spheal + - sealeo + - walrein + - clamperl + - huntail + - gorebyss + - relicanth + - luvdisc + - bagon + - shelgon + - salamence + - beldum + - metang + - metagross + - regirock + - regice + - registeel + - latias + - latios + - kyogre + - groudon + - rayquaza + - jirachi + - deoxys + - turtwig + - grotle + - torterra + - chimchar + - monferno + - infernape + - piplup + - prinplup + - empoleon + - starly + - staravia + - staraptor + - bidoof + - bibarel + - kricketot + - kricketune + - shinx + - luxio + - luxray + - budew + - roserade + - cranidos + - rampardos + - shieldon + - bastiodon + - burmy + - wormadam + - mothim + - combee + - vespiquen + - pachirisu + - buizel + - floatzel + - cherubi + - cherrim + - shellos + - gastrodon + - ambipom + - drifloon + - drifblim + - buneary + - lopunny + - mismagius + - honchkrow + - glameow + - purugly + - chingling + - stunky + - skuntank + - bronzor + - bronzong + - bonsly + - mimejr + - happiny + - chatot + - spiritomb + - gible + - gabite + - garchomp + - munchlax + - riolu + - lucario + - hippopotas + - hippowdon + - skorupi + - drapion + - croagunk + - toxicroak + - carnivine + - finneon + - lumineon + - mantyke + - snover + - abomasnow + - weavile + - magnezone + - lickilicky + - rhyperior + - tangrowth + - electivire + - magmortar + - togekiss + - yanmega + - leafeon + - glaceon + - gliscor + - mamoswine + - porygon-z + - gallade + - probopass + - dusknoir + - froslass + - rotom + - uxie + - mesprit + - azelf + - dialga + - palkia + - heatran + - regigigas + - giratina + - cresselia + - phione + - manaphy + - darkrai + - shaymin + - arceus + - victini + - snivy + - servine + - serperior + - tepig + - pignite + - emboar + - oshawott + - dewott + - samurott + - patrat + - watchog + - lillipup + - herdier + - stoutland + - purrloin + - liepard + - pansage + - simisage + - pansear + - simisear + - panpour + - simipour + - munna + - musharna + - pidove + - tranquill + - unfezant + - blitzle + - zebstrika + - roggenrola + - boldore + - gigalith + - woobat + - swoobat + - drilbur + - excadrill + - audino + - timburr + - gurdurr + - conkeldurr + - tympole + - palpitoad + - seismitoad + - throh + - sawk + - sewaddle + - swadloon + - leavanny + - venipede + - whirlipede + - scolipede + - cottonee + - whimsicott + - petilil + - lilligant + - basculin + - sandile + - krokorok + - krookodile + - darumaka + - darmanitan + - maractus + - dwebble + - crustle + - scraggy + - scrafty + - sigilyph + - yamask + - cofagrigus + - tirtouga + - carracosta + - archen + - archeops + - trubbish + - garbodor + - zorua + - zoroark + - minccino + - cinccino + - gothita + - gothorita + - gothitelle 
+ - solosis + - duosion + - reuniclus + - ducklett + - swanna + - vanillite + - vanillish + - vanilluxe + - deerling + - sawsbuck + - emolga + - karrablast + - escavalier + - foongus + - amoonguss + - frillish + - jellicent + - alomomola + - joltik + - galvantula + - ferroseed + - ferrothorn + - klink + - klang + - klinklang + - tynamo + - eelektrik + - eelektross + - elgyem + - beheeyem + - litwick + - lampent + - chandelure + - axew + - fraxure + - haxorus + - cubchoo + - beartic + - cryogonal + - shelmet + - accelgor + - stunfisk + - mienfoo + - mienshao + - druddigon + - golett + - golurk + - pawniard + - bisharp + - bouffalant + - rufflet + - braviary + - vullaby + - mandibuzz + - heatmor + - durant + - deino + - zweilous + - hydreigon + - larvesta + - volcarona + - cobalion + - terrakion + - virizion + - tornadus + - thundurus + - reshiram + - zekrom + - landorus + - kyurem + - keldeo + - meloetta + - genesect + - chespin + - quilladin + - chesnaught + - fennekin + - braixen + - delphox + - froakie + - frogadier + - greninja + - bunnelby + - diggersby + - fletchling + - fletchinder + - talonflame + - scatterbug + - spewpa + - vivillon + - litleo + - pyroar + - flabebe + - floette + - florges + - skiddo + - gogoat + - pancham + - pangoro + - furfrou + - espurr + - meowstic + - honedge + - doublade + - aegislash + - spritzee + - aromatisse + - swirlix + - slurpuff + - inkay + - malamar + - binacle + - barbaracle + - skrelp + - dragalge + - clauncher + - clawitzer + - helioptile + - heliolisk + - tyrunt + - tyrantrum + - amaura + - aurorus + - sylveon + - hawlucha + - dedenne + - carbink + - goomy + - sliggoo + - goodra + - klefki + - phantump + - trevenant + - pumpkaboo + - gourgeist + - bergmite + - avalugg + - noibat + - noivern + - xerneas + - yveltal + - zygarde + - diancie + - hoopa + - volcanion + - rowlet + - dartrix + - decidueye + - litten + - torracat + - incineroar + - popplio + - brionne + - primarina + - pikipek + - trumbeak + - toucannon + - yungoos + - gumshoos + - grubbin + - charjabug + - vikavolt + - crabrawler + - crabominable + - oricorio + - cutiefly + - ribombee + - rockruff + - lycanroc + - wishiwashi + - mareanie + - toxapex + - mudbray + - mudsdale + - dewpider + - araquanid + - fomantis + - lurantis + - morelull + - shiinotic + - salandit + - salazzle + - stufful + - bewear + - bounsweet + - steenee + - tsareena + - comfey + - oranguru + - passimian + - wimpod + - golisopod + - sandygast + - palossand + - pyukumuku + - typenull + - silvally + - minior + - komala + - turtonator + - togedemaru + - mimikyu + - bruxish + - drampa + - dhelmise + - jangmo-o + - hakamo-o + - kommo-o + - tapukoko + - tapulele + - tapubulu + - tapufini + - cosmog + - cosmoem + - solgaleo + - lunala + - nihilego + - buzzwole + - pheromosa + - xurkitree + - celesteela + - kartana + - guzzlord + - necrozma + - magearna + - marshadow + - poipole + - naganadel + - stakataka + - blacephalon + - zeraora + - meltan + - melmetal + - grookey + - thwackey + - rillaboom + - scorbunny + - raboot + - cinderace + - sobble + - drizzile + - inteleon + - skwovet + - greedent + - rookidee + - corvisquire + - corviknight + - blipbug + - dottler + - orbeetle + - nickit + - thievul + - gossifleur + - eldegoss + - wooloo + - dubwool + - chewtle + - drednaw + - yamper + - boltund + - rolycoly + - carkol + - coalossal + - applin + - flapple + - appletun + - silicobra + - sandaconda + - cramorant + - arrokuda + - barraskewda + - toxel + - toxtricity + - sizzlipede + - centiskorch + - clobbopus + - grapploct + - 
sinistea
+          - polteageist
+          - hatenna
+          - hattrem
+          - hatterene
+          - impidimp
+          - morgrem
+          - grimmsnarl
+          - obstagoon
+          - perrserker
+          - cursola
+          - sirfetchd
+          - mrrime
+          - runerigus
+          - milcery
+          - alcremie
+          - falinks
+          - pincurchin
+          - snom
+          - frosmoth
+          - stonjourner
+          - eiscue
+          - indeedee
+          - morpeko
+          - cufant
+          - copperajah
+          - dracozolt
+          - arctozolt
+          - dracovish
+          - arctovish
+          - duraludon
+          - dreepy
+          - drakloak
+          - dragapult
+          - zacian
+          - zamazenta
+          - eternatus
+          - kubfu
+          - urshifu
+          - zarude
+          - regieleki
+          - regidrago
+          - glastrier
+          - spectrier
+          - calyrex
+        order: 0
+        title: Pokemon Name
+        pattern: ^[a-z0-9_\-]+$
+        examples:
+          - ditto
+          - luxray
+          - snorlax
+    additionalProperties: true
+
+metadata:
+  assist: {}
+  testedStreams:
+    pokemon:
+      hasRecords: true
+      streamHash: 71d50b057104f51772e5ef731e580332145d89dd
+      hasResponse: true
+      primaryKeysAreUnique: true
+      primaryKeysArePresent: true
+      responsesAreSuccessful: true
+  autoImportSchema:
+    pokemon: false
+
+schemas:
+  pokemon:
+    type: object
+    $schema: http://json-schema.org/draft-07/schema#
+    properties: {}
+    additionalProperties: true
diff --git a/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/valid_config.yaml b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/valid_config.yaml
new file mode 100644
index 000000000..78af092bb
--- /dev/null
+++ b/unit_tests/source_declarative_manifest/resources/source_pokeapi_w_components_py/valid_config.yaml
@@ -0,0 +1 @@
+{ "start_date": "2024-01-01", "pokemon_name": "pikachu" }
diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/.gitignore b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/.gitignore
deleted file mode 100644
index c4ab49a30..000000000
--- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-secrets*
diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/README.md b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/README.md
deleted file mode 100644
index 403a4baba..000000000
--- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/README.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# The Guardian API Tests
-
-For these tests to work, you'll need to create a `secrets.yaml` file in this directory that looks like this:
-
-```yml
-api_key: ******
-```
-
-The `.gitignore` file in this directory should ensure your file is not committed to git, but it's a good practice to double-check. 👀
diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components.py b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components.py
deleted file mode 100644
index 98a9f7ad5..000000000
--- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#
-# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-#
-
-from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
-
-import requests
-
-from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
-from airbyte_cdk.sources.declarative.requesters.paginators import PaginationStrategy
-from airbyte_cdk.sources.declarative.types import Config, Record
-
-
-@dataclass
-class CustomPageIncrement(PaginationStrategy):
-    """
-    Starts page from 1 instead of the default value that is 0.
Stops Pagination when currentPage is equal to totalPages. - """ - - config: Config - page_size: Optional[Union[str, int]] - parameters: InitVar[Mapping[str, Any]] - start_from_page: int = 0 - inject_on_first_request: bool = False - - def __post_init__(self, parameters: Mapping[str, Any]) -> None: - if isinstance(self.page_size, int) or (self.page_size is None): - self._page_size = self.page_size - else: - page_size = InterpolatedString(self.page_size, parameters=parameters).eval(self.config) - if not isinstance(page_size, int): - raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") - self._page_size = page_size - - @property - def initial_token(self) -> Optional[Any]: - if self.inject_on_first_request: - return self.start_from_page - return None - - def next_page_token( - self, - response: requests.Response, - last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Any]: - res = response.json().get("response") - current_page = res.get("currentPage") - total_pages = res.get("pages") - - # The first request to the API does not include the page_token, so it comes in as None when determing whether to paginate - last_page_token_value = last_page_token_value or 0 - if current_page < total_pages: - return last_page_token_value + 1 - else: - return None - - def get_page_size(self) -> Optional[int]: - return self._page_size diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components_failing.py b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components_failing.py deleted file mode 100644 index 8655bdf2d..000000000 --- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/components_failing.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -# - -from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union - -import requests - -from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.paginators import PaginationStrategy -from airbyte_cdk.sources.declarative.types import Config, Record - - -class IntentionalException(Exception): - """This exception is raised intentionally in order to test error handling.""" - - -@dataclass -class CustomPageIncrement(PaginationStrategy): - """ - Starts page from 1 instead of the default value that is 0. Stops Pagination when currentPage is equal to totalPages. - """ - - config: Config - page_size: Optional[Union[str, int]] - parameters: InitVar[Mapping[str, Any]] - start_from_page: int = 0 - inject_on_first_request: bool = False - - def __post_init__(self, parameters: Mapping[str, Any]) -> None: - if isinstance(self.page_size, int) or (self.page_size is None): - self._page_size = self.page_size - else: - page_size = InterpolatedString(self.page_size, parameters=parameters).eval(self.config) - if not isinstance(page_size, int): - raise Exception(f"{page_size} is of type {type(page_size)}. 
Expected {int}") - self._page_size = page_size - - @property - def initial_token(self) -> Optional[Any]: - raise IntentionalException() - - def next_page_token( - self, - response: requests.Response, - last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Any]: - raise IntentionalException() - - def get_page_size(self) -> Optional[int]: - return self._page_size diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/manifest.yaml b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/manifest.yaml deleted file mode 100644 index a42e0ebba..000000000 --- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/manifest.yaml +++ /dev/null @@ -1,376 +0,0 @@ -version: "4.3.2" -definitions: - selector: - extractor: - field_path: - - response - - results - requester: - url_base: "https://content.guardianapis.com" - http_method: "GET" - request_parameters: - api-key: "{{ config['api_key'] }}" - q: "{{ config['query'] }}" - tag: "{{ config['tag'] }}" - section: "{{ config['section'] }}" - order-by: "oldest" - incremental_sync: - type: DatetimeBasedCursor - start_datetime: - datetime: "{{ config['start_date'] }}" - datetime_format: "%Y-%m-%d" - end_datetime: - datetime: "{{ config['end_date'] or now_utc().strftime('%Y-%m-%d') }}" - datetime_format: "%Y-%m-%d" - step: "P7D" - datetime_format: "%Y-%m-%dT%H:%M:%SZ" - cursor_granularity: "PT1S" - cursor_field: "webPublicationDate" - start_time_option: - field_name: "from-date" - inject_into: "request_parameter" - end_time_option: - field_name: "to-date" - inject_into: "request_parameter" - retriever: - record_selector: - extractor: - field_path: - - response - - results - paginator: - type: DefaultPaginator - pagination_strategy: - type: CustomPaginationStrategy - class_name: "CustomPageIncrement" - page_size: 10 - page_token_option: - type: RequestOption - inject_into: "request_parameter" - field_name: "page" - page_size_option: - inject_into: "body_data" - field_name: "page_size" - requester: - url_base: "https://content.guardianapis.com" - http_method: "GET" - request_parameters: - api-key: "{{ config['api_key'] }}" - q: "{{ config['query'] }}" - tag: "{{ config['tag'] }}" - section: "{{ config['section'] }}" - order-by: "oldest" - base_stream: - incremental_sync: - type: DatetimeBasedCursor - start_datetime: - datetime: "{{ config['start_date'] }}" - datetime_format: "%Y-%m-%d" - end_datetime: - datetime: "{{ config['end_date'] or now_utc().strftime('%Y-%m-%d') }}" - datetime_format: "%Y-%m-%d" - step: "P7D" - datetime_format: "%Y-%m-%dT%H:%M:%SZ" - cursor_granularity: "PT1S" - cursor_field: "webPublicationDate" - start_time_option: - field_name: "from-date" - inject_into: "request_parameter" - end_time_option: - field_name: "to-date" - inject_into: "request_parameter" - retriever: - record_selector: - extractor: - field_path: - - response - - results - paginator: - type: DefaultPaginator - pagination_strategy: - type: CustomPaginationStrategy - class_name: "CustomPageIncrement" - page_size: 10 - page_token_option: - type: RequestOption - inject_into: "request_parameter" - field_name: "page" - page_size_option: - inject_into: "body_data" - field_name: "page_size" - requester: - url_base: "https://content.guardianapis.com" - http_method: "GET" - request_parameters: - api-key: "{{ config['api_key'] }}" - q: "{{ config['query'] }}" - tag: "{{ config['tag'] }}" - section: "{{ config['section'] }}" - order-by: "oldest" - 
content_stream: - incremental_sync: - type: DatetimeBasedCursor - start_datetime: - datetime: "{{ config['start_date'] }}" - datetime_format: "%Y-%m-%d" - end_datetime: - datetime: "{{ config['end_date'] or now_utc().strftime('%Y-%m-%d') }}" - datetime_format: "%Y-%m-%d" - step: "P7D" - datetime_format: "%Y-%m-%dT%H:%M:%SZ" - cursor_granularity: "PT1S" - cursor_field: "webPublicationDate" - start_time_option: - field_name: "from-date" - inject_into: "request_parameter" - end_time_option: - field_name: "to-date" - inject_into: "request_parameter" - retriever: - record_selector: - extractor: - field_path: - - response - - results - paginator: - type: "DefaultPaginator" - pagination_strategy: - type: CustomPaginationStrategy - class_name: "components.CustomPageIncrement" - page_size: 10 - page_token_option: - type: RequestOption - inject_into: "request_parameter" - field_name: "page" - page_size_option: - inject_into: "body_data" - field_name: "page_size" - requester: - url_base: "https://content.guardianapis.com" - http_method: "GET" - request_parameters: - api-key: "{{ config['api_key'] }}" - q: "{{ config['query'] }}" - tag: "{{ config['tag'] }}" - section: "{{ config['section'] }}" - order-by: "oldest" - schema_loader: - type: InlineSchemaLoader - schema: - $schema: http://json-schema.org/draft-04/schema# - type: object - properties: - id: - type: string - type: - type: string - sectionId: - type: string - sectionName: - type: string - webPublicationDate: - type: string - webTitle: - type: string - webUrl: - type: string - apiUrl: - type: string - isHosted: - type: boolean - pillarId: - type: string - pillarName: - type: string - required: - - id - - type - - sectionId - - sectionName - - webPublicationDate - - webTitle - - webUrl - - apiUrl - - isHosted - - pillarId - - pillarName -streams: - - incremental_sync: - type: DatetimeBasedCursor - start_datetime: - datetime: "{{ config['start_date'] }}" - datetime_format: "%Y-%m-%d" - type: MinMaxDatetime - end_datetime: - datetime: "{{ config['end_date'] or now_utc().strftime('%Y-%m-%d') }}" - datetime_format: "%Y-%m-%d" - type: MinMaxDatetime - step: "P7D" - datetime_format: "%Y-%m-%dT%H:%M:%SZ" - cursor_granularity: "PT1S" - cursor_field: "webPublicationDate" - start_time_option: - field_name: "from-date" - inject_into: "request_parameter" - type: RequestOption - end_time_option: - field_name: "to-date" - inject_into: "request_parameter" - type: RequestOption - retriever: - record_selector: - extractor: - field_path: - - response - - results - type: DpathExtractor - type: RecordSelector - paginator: - type: "DefaultPaginator" - pagination_strategy: - class_name: components.CustomPageIncrement - page_size: 10 - type: CustomPaginationStrategy - page_token_option: - type: RequestOption - inject_into: "request_parameter" - field_name: "page" - page_size_option: - inject_into: "body_data" - field_name: "page_size" - type: RequestOption - requester: - url_base: "https://content.guardianapis.com" - http_method: "GET" - request_parameters: - api-key: "{{ config['api_key'] }}" - q: "{{ config['query'] }}" - tag: "{{ config['tag'] }}" - section: "{{ config['section'] }}" - order-by: "oldest" - type: HttpRequester - path: "/search" - type: SimpleRetriever - schema_loader: - type: InlineSchemaLoader - schema: - $schema: http://json-schema.org/draft-04/schema# - type: object - properties: - id: - type: string - type: - type: string - sectionId: - type: string - sectionName: - type: string - webPublicationDate: - type: string - webTitle: - type: string - 
webUrl: - type: string - apiUrl: - type: string - isHosted: - type: boolean - pillarId: - type: string - pillarName: - type: string - required: - - id - - type - - sectionId - - sectionName - - webPublicationDate - - webTitle - - webUrl - - apiUrl - - isHosted - - pillarId - - pillarName - type: DeclarativeStream - name: "content" - primary_key: "id" -check: - stream_names: - - "content" - type: CheckStream -type: DeclarativeSource -spec: - type: Spec - documentation_url: https://docs.airbyte.com/integrations/sources/the-guardian-api - connection_specification: - $schema: http://json-schema.org/draft-07/schema# - title: The Guardian Api Spec - type: object - required: - - api_key - - start_date - additionalProperties: true - properties: - api_key: - title: API Key - type: string - description: - Your API Key. See here. - The key is case sensitive. - airbyte_secret: true - start_date: - title: Start Date - type: string - description: - Use this to set the minimum date (YYYY-MM-DD) of the results. - Results older than the start_date will not be shown. - pattern: ^([1-9][0-9]{3})\-(0?[1-9]|1[012])\-(0?[1-9]|[12][0-9]|3[01])$ - examples: - - YYYY-MM-DD - query: - title: Query - type: string - description: - (Optional) The query (q) parameter filters the results to only - those that include that search term. The q parameter supports AND, OR and - NOT operators. - examples: - - environment AND NOT water - - environment AND political - - amusement park - - political - tag: - title: Tag - type: string - description: - (Optional) A tag is a piece of data that is used by The Guardian - to categorise content. Use this parameter to filter results by showing only - the ones matching the entered tag. See here - for a list of all tags, and here - for the tags endpoint documentation. - examples: - - environment/recycling - - environment/plasticbags - - environment/energyefficiency - section: - title: Section - type: string - description: - (Optional) Use this to filter the results by a particular section. - See here - for a list of all sections, and here - for the sections endpoint documentation. - examples: - - media - - technology - - housing-network - end_date: - title: End Date - type: string - description: - (Optional) Use this to set the maximum date (YYYY-MM-DD) of the - results. Results newer than the end_date will not be shown. Default is set - to the current date (today) for incremental syncs. 
- pattern: ^([1-9][0-9]{3})\-(0?[1-9]|1[012])\-(0?[1-9]|[12][0-9]|3[01])$ - examples: - - YYYY-MM-DD diff --git a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/valid_config.yaml b/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/valid_config.yaml deleted file mode 100644 index b2f752ea1..000000000 --- a/unit_tests/source_declarative_manifest/resources/source_the_guardian_api/valid_config.yaml +++ /dev/null @@ -1 +0,0 @@ -{ "start_date": "2024-01-01" } diff --git a/unit_tests/source_declarative_manifest/test_source_declarative_w_custom_components.py b/unit_tests/source_declarative_manifest/test_source_declarative_w_custom_components.py index d608e7620..40bb6d40b 100644 --- a/unit_tests/source_declarative_manifest/test_source_declarative_w_custom_components.py +++ b/unit_tests/source_declarative_manifest/test_source_declarative_w_custom_components.py @@ -89,9 +89,8 @@ def test_components_module_from_string() -> None: def get_py_components_config_dict( *, failing_components: bool = False, - needs_secrets: bool = True, ) -> dict[str, Any]: - connector_dir = Path(get_fixture_path("resources/source_the_guardian_api")) + connector_dir = Path(get_fixture_path("resources/source_pokeapi_w_components_py")) manifest_yml_path: Path = connector_dir / "manifest.yaml" custom_py_code_path: Path = connector_dir / ( "components.py" if not failing_components else "components_failing.py" @@ -115,9 +114,6 @@ def get_py_components_config_dict( }, } combined_config_dict.update(yaml.safe_load(config_yaml_path.read_text())) - if needs_secrets: - combined_config_dict.update(yaml.safe_load(secrets_yaml_path.read_text())) - return combined_config_dict @@ -127,9 +123,7 @@ def test_missing_checksum_fails_to_run( """Assert that missing checksum in the config will raise an error.""" monkeypatch.setenv(ENV_VAR_ALLOW_CUSTOM_CODE, "true") - py_components_config_dict = get_py_components_config_dict( - needs_secrets=False, - ) + py_components_config_dict = get_py_components_config_dict() # Truncate the start_date to speed up tests py_components_config_dict["start_date"] = ( datetime.datetime.now() - datetime.timedelta(days=2) @@ -161,9 +155,7 @@ def test_invalid_checksum_fails_to_run( """Assert that an invalid checksum in the config will raise an error.""" monkeypatch.setenv(ENV_VAR_ALLOW_CUSTOM_CODE, "true") - py_components_config_dict = get_py_components_config_dict( - needs_secrets=False, - ) + py_components_config_dict = get_py_components_config_dict() # Truncate the start_date to speed up tests py_components_config_dict["start_date"] = ( datetime.datetime.now() - datetime.timedelta(days=2) @@ -210,9 +202,7 @@ def test_fail_unless_custom_code_enabled_explicitly( assert custom_code_execution_permitted() == (not should_raise) - py_components_config_dict = get_py_components_config_dict( - needs_secrets=False, - ) + py_components_config_dict = get_py_components_config_dict() # Truncate the start_date to speed up tests py_components_config_dict["start_date"] = ( datetime.datetime.now() - datetime.timedelta(days=2) @@ -234,11 +224,6 @@ def test_fail_unless_custom_code_enabled_explicitly( fn() -# TODO: Create a new test source that doesn't require credentials to run. 
-@pytest.mark.skipif( - condition=not Path(get_fixture_path("resources/source_the_guardian_api/secrets.yaml")).exists(), - reason="Skipped due to missing 'secrets.yaml'.", -) @pytest.mark.parametrize( "failing_components", [ @@ -288,17 +273,19 @@ def test_sync_with_injected_py_components( ] ) - msg_iterator = source.read( - logger=logging.getLogger(), - config=py_components_config_dict, - catalog=configured_catalog, - state=None, - ) - if failing_components: - with pytest.raises(Exception): - for msg in msg_iterator: - assert msg + def _read_fn(*args, **kwargs): + msg_iterator = source.read( + logger=logging.getLogger(), + config=py_components_config_dict, + catalog=configured_catalog, + state=None, + ) + for msg in msg_iterator: + assert msg return - for msg in msg_iterator: - assert msg + if failing_components: + with pytest.raises(Exception): + _read_fn() + else: + _read_fn()
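A closing note on the checksum tests above: before custom code is allowed to run, the config must carry a digest of the injected `components.py`, and a missing or invalid digest raises an error. A minimal sketch of how a caller could compute one (the `__injected_...` key names are an assumption for illustration, not taken from this patch):

```python
import hashlib
from pathlib import Path

components_path = Path("resources/source_pokeapi_w_components_py/components.py")

# Hex MD5 digest of the file's bytes; an invalid or missing digest makes the
# source refuse to run, as the tests above assert.
checksum = hashlib.md5(components_path.read_bytes()).hexdigest()

config = {
    # Hypothetical key names; the CDK's real config keys may differ.
    "__injected_components_py": components_path.read_text(),
    "__injected_components_py_checksums": {"md5": checksum},
}
```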