Skip to content

Commit f55abc1

Browse files
authored
🐛 low-code: Fix incremental substreams (#35471)
1 parent 95afe28 commit f55abc1

23 files changed

+833
-298
lines changed

airbyte-cdk/python/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import datetime as dt
66
from dataclasses import InitVar, dataclass, field
7-
from typing import Any, Mapping, Union
7+
from typing import Any, Mapping, Optional, Union
88

99
from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
1010
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
@@ -37,13 +37,13 @@ class MinMaxDatetime:
3737
min_datetime: Union[InterpolatedString, str] = ""
3838
max_datetime: Union[InterpolatedString, str] = ""
3939

40-
def __post_init__(self, parameters: Mapping[str, Any]):
40+
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
4141
self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {})
4242
self._parser = DatetimeParser()
43-
self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None
44-
self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None
43+
self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None # type: ignore
44+
self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None # type: ignore
4545

46-
def get_datetime(self, config, **additional_parameters) -> dt.datetime:
46+
def get_datetime(self, config: Mapping[str, Any], **additional_parameters: Mapping[str, Any]) -> dt.datetime:
4747
"""
4848
Evaluates and returns the datetime
4949
:param config: The user-provided configuration as specified by the source's spec
@@ -55,29 +55,44 @@ def get_datetime(self, config, **additional_parameters) -> dt.datetime:
5555
if not datetime_format:
5656
datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z"
5757

58-
time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format)
58+
time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format) # type: ignore # datetime is always cast to an interpolated string
5959

6060
if self.min_datetime:
61-
min_time = str(self.min_datetime.eval(config, **additional_parameters))
61+
min_time = str(self.min_datetime.eval(config, **additional_parameters)) # type: ignore # min_datetime is always cast to an interpolated string
6262
if min_time:
63-
min_time = self._parser.parse(min_time, datetime_format)
64-
time = max(time, min_time)
63+
min_datetime = self._parser.parse(min_time, datetime_format) # type: ignore # min_datetime is always cast to an interpolated string
64+
time = max(time, min_datetime)
6565
if self.max_datetime:
66-
max_time = str(self.max_datetime.eval(config, **additional_parameters))
66+
max_time = str(self.max_datetime.eval(config, **additional_parameters)) # type: ignore # max_datetime is always cast to an interpolated string
6767
if max_time:
68-
max_time = self._parser.parse(max_time, datetime_format)
69-
time = min(time, max_time)
68+
max_datetime = self._parser.parse(max_time, datetime_format)
69+
time = min(time, max_datetime)
7070
return time
7171

72-
@property
72+
@property # type: ignore # properties don't play well with dataclasses...
7373
def datetime_format(self) -> str:
7474
"""The format of the string representing the datetime"""
7575
return self._datetime_format
7676

7777
@datetime_format.setter
78-
def datetime_format(self, value: str):
78+
def datetime_format(self, value: str) -> None:
7979
"""Setter for the datetime format"""
8080
# Covers the case where datetime_format is not provided in the constructor, which causes the property object
8181
# to be set which we need to avoid doing
8282
if not isinstance(value, property):
8383
self._datetime_format = value
84+
85+
@classmethod
86+
def create(
87+
cls,
88+
interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"],
89+
parameters: Optional[Mapping[str, Any]] = None,
90+
) -> "MinMaxDatetime":
91+
if parameters is None:
92+
parameters = {}
93+
if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance(
94+
interpolated_string_or_min_max_datetime, str
95+
):
96+
return MinMaxDatetime(datetime=interpolated_string_or_min_max_datetime, parameters=parameters)
97+
else:
98+
return interpolated_string_or_min_max_datetime

airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
1111
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
1212
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
13-
from airbyte_cdk.sources.declarative.types import Config
13+
from airbyte_cdk.sources.declarative.types import Config, StreamSlice
1414
from airbyte_cdk.sources.streams.core import Stream
1515

1616

@@ -101,6 +101,8 @@ def read_records(
101101
"""
102102
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
103103
"""
104+
if not isinstance(stream_slice, StreamSlice):
105+
raise ValueError(f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}")
104106
yield from self.retriever.read_records(self.get_json_schema(), stream_slice)
105107

106108
def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
@@ -114,7 +116,7 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
114116

115117
def stream_slices(
116118
self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
117-
) -> Iterable[Optional[Mapping[str, Any]]]:
119+
) -> Iterable[Optional[StreamSlice]]:
118120
"""
119121
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
120122

airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py

+54-38
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import datetime
66
from dataclasses import InitVar, dataclass, field
7-
from typing import Any, Iterable, List, Mapping, Optional, Union
7+
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union
88

99
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
1010
from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
@@ -70,10 +70,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7070
f"If step is defined, cursor_granularity should be as well and vice-versa. "
7171
f"Right now, step is `{self.step}` and cursor_granularity is `{self.cursor_granularity}`"
7272
)
73-
if not isinstance(self.start_datetime, MinMaxDatetime):
74-
self.start_datetime = MinMaxDatetime(self.start_datetime, parameters)
75-
if self.end_datetime and not isinstance(self.end_datetime, MinMaxDatetime):
76-
self.end_datetime = MinMaxDatetime(self.end_datetime, parameters)
73+
self._start_datetime = MinMaxDatetime.create(self.start_datetime, parameters)
74+
self._end_datetime = None if not self.end_datetime else MinMaxDatetime.create(self.end_datetime, parameters)
7775

7876
self._timezone = datetime.timezone.utc
7977
self._interpolation = JinjaInterpolation()
@@ -84,23 +82,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
8482
else datetime.timedelta.max
8583
)
8684
self._cursor_granularity = self._parse_timedelta(self.cursor_granularity)
87-
self.cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
88-
self.lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters)
89-
self.partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
90-
self.partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)
85+
self._cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
86+
self._lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) if self.lookback_window else None
87+
self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
88+
self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)
9189
self._parser = DatetimeParser()
9290

9391
# If datetime format is not specified then start/end datetime should inherit it from the stream slicer
94-
if not self.start_datetime.datetime_format:
95-
self.start_datetime.datetime_format = self.datetime_format
96-
if self.end_datetime and not self.end_datetime.datetime_format:
97-
self.end_datetime.datetime_format = self.datetime_format
92+
if not self._start_datetime.datetime_format:
93+
self._start_datetime.datetime_format = self.datetime_format
94+
if self._end_datetime and not self._end_datetime.datetime_format:
95+
self._end_datetime.datetime_format = self.datetime_format
9896

9997
if not self.cursor_datetime_formats:
10098
self.cursor_datetime_formats = [self.datetime_format]
10199

102100
def get_stream_state(self) -> StreamState:
103-
return {self.cursor_field.eval(self.config): self._cursor} if self._cursor else {}
101+
return {self._cursor_field.eval(self.config): self._cursor} if self._cursor else {}
104102

105103
def set_initial_state(self, stream_state: StreamState) -> None:
106104
"""
@@ -109,17 +107,22 @@ def set_initial_state(self, stream_state: StreamState) -> None:
109107
110108
:param stream_state: The state of the stream as returned by get_stream_state
111109
"""
112-
self._cursor = stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None
110+
self._cursor = stream_state.get(self._cursor_field.eval(self.config)) if stream_state else None
113111

114112
def close_slice(self, stream_slice: StreamSlice, most_recent_record: Optional[Record]) -> None:
115-
last_record_cursor_value = most_recent_record.get(self.cursor_field.eval(self.config)) if most_recent_record else None
116-
stream_slice_value_end = stream_slice.get(self.partition_field_end.eval(self.config))
113+
if stream_slice.partition:
114+
raise ValueError(f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}.")
115+
last_record_cursor_value = most_recent_record.get(self._cursor_field.eval(self.config)) if most_recent_record else None
116+
stream_slice_value_end = stream_slice.get(self._partition_field_end.eval(self.config))
117+
potential_cursor_values = [
118+
cursor_value for cursor_value in [self._cursor, last_record_cursor_value, stream_slice_value_end] if cursor_value
119+
]
117120
cursor_value_str_by_cursor_value_datetime = dict(
118121
map(
119122
# we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like
120123
# 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z'
121124
lambda datetime_str: (self.parse_date(datetime_str), datetime_str),
122-
filter(lambda item: item, [self._cursor, last_record_cursor_value, stream_slice_value_end]),
125+
potential_cursor_values,
123126
)
124127
)
125128
self._cursor = (
@@ -142,37 +145,43 @@ def stream_slices(self) -> Iterable[StreamSlice]:
142145
return self._partition_daterange(start_datetime, end_datetime, self._step)
143146

144147
def _calculate_earliest_possible_value(self, end_datetime: datetime.datetime) -> datetime.datetime:
145-
lookback_delta = self._parse_timedelta(self.lookback_window.eval(self.config) if self.lookback_window else "P0D")
146-
earliest_possible_start_datetime = min(self.start_datetime.get_datetime(self.config), end_datetime)
148+
lookback_delta = self._parse_timedelta(self._lookback_window.eval(self.config) if self.lookback_window else "P0D")
149+
earliest_possible_start_datetime = min(self._start_datetime.get_datetime(self.config), end_datetime)
147150
cursor_datetime = self._calculate_cursor_datetime_from_state(self.get_stream_state())
148151
return max(earliest_possible_start_datetime, cursor_datetime) - lookback_delta
149152

150153
def _select_best_end_datetime(self) -> datetime.datetime:
151154
now = datetime.datetime.now(tz=self._timezone)
152-
if not self.end_datetime:
155+
if not self._end_datetime:
153156
return now
154-
return min(self.end_datetime.get_datetime(self.config), now)
157+
return min(self._end_datetime.get_datetime(self.config), now)
155158

156159
def _calculate_cursor_datetime_from_state(self, stream_state: Mapping[str, Any]) -> datetime.datetime:
157-
if self.cursor_field.eval(self.config, stream_state=stream_state) in stream_state:
158-
return self.parse_date(stream_state[self.cursor_field.eval(self.config)])
160+
if self._cursor_field.eval(self.config, stream_state=stream_state) in stream_state:
161+
return self.parse_date(stream_state[self._cursor_field.eval(self.config)])
159162
return datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)
160163

161164
def _format_datetime(self, dt: datetime.datetime) -> str:
162165
return self._parser.format(dt, self.datetime_format)
163166

164-
def _partition_daterange(self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration]):
165-
start_field = self.partition_field_start.eval(self.config)
166-
end_field = self.partition_field_end.eval(self.config)
167+
def _partition_daterange(
168+
self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration]
169+
) -> List[StreamSlice]:
170+
start_field = self._partition_field_start.eval(self.config)
171+
end_field = self._partition_field_end.eval(self.config)
167172
dates = []
168173
while start <= end:
169174
next_start = self._evaluate_next_start_date_safely(start, step)
170175
end_date = self._get_date(next_start - self._cursor_granularity, end, min)
171-
dates.append({start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)})
176+
dates.append(
177+
StreamSlice(
178+
partition={}, cursor_slice={start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)}
179+
)
180+
)
172181
start = next_start
173182
return dates
174183

175-
def _evaluate_next_start_date_safely(self, start, step):
184+
def _evaluate_next_start_date_safely(self, start: datetime.datetime, step: datetime.timedelta) -> datetime.datetime:
176185
"""
177186
Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date
178187
This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code
@@ -183,7 +192,12 @@ def _evaluate_next_start_date_safely(self, start, step):
183192
except OverflowError:
184193
return datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)
185194

186-
def _get_date(self, cursor_value, default_date: datetime.datetime, comparator) -> datetime.datetime:
195+
def _get_date(
196+
self,
197+
cursor_value: datetime.datetime,
198+
default_date: datetime.datetime,
199+
comparator: Callable[[datetime.datetime, datetime.datetime], datetime.datetime],
200+
) -> datetime.datetime:
187201
cursor_date = cursor_value or default_date
188202
return comparator(cursor_date, default_date)
189203

@@ -196,7 +210,7 @@ def parse_date(self, date: str) -> datetime.datetime:
196210
raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}")
197211

198212
@classmethod
199-
def _parse_timedelta(cls, time_str) -> Union[datetime.timedelta, Duration]:
213+
def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]:
200214
"""
201215
:return Parses an ISO 8601 durations into datetime.timedelta or Duration objects.
202216
"""
@@ -244,18 +258,20 @@ def request_kwargs(self) -> Mapping[str, Any]:
244258
# Never update kwargs
245259
return {}
246260

247-
def _get_request_options(self, option_type: RequestOptionType, stream_slice: StreamSlice):
248-
options = {}
261+
def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]:
262+
options: MutableMapping[str, Any] = {}
263+
if not stream_slice:
264+
return options
249265
if self.start_time_option and self.start_time_option.inject_into == option_type:
250-
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get(
251-
self.partition_field_start.eval(self.config)
266+
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string
267+
self._partition_field_start.eval(self.config)
252268
)
253269
if self.end_time_option and self.end_time_option.inject_into == option_type:
254-
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self.partition_field_end.eval(self.config))
270+
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self._partition_field_end.eval(self.config)) # type: ignore # field_name is always casted to an interpolated string
255271
return options
256272

257273
def should_be_synced(self, record: Record) -> bool:
258-
cursor_field = self.cursor_field.eval(self.config)
274+
cursor_field = self._cursor_field.eval(self.config)
259275
record_cursor_value = record.get(cursor_field)
260276
if not record_cursor_value:
261277
self._send_log(
@@ -278,7 +294,7 @@ def _send_log(self, level: Level, message: str) -> None:
278294
)
279295

280296
def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
281-
cursor_field = self.cursor_field.eval(self.config)
297+
cursor_field = self._cursor_field.eval(self.config)
282298
first_cursor_value = first.get(cursor_field)
283299
second_cursor_value = second.get(cursor_field)
284300
if first_cursor_value and second_cursor_value:

0 commit comments

Comments
 (0)