4
4
5
5
import datetime
6
6
from dataclasses import InitVar , dataclass , field
7
- from typing import Any , Iterable , List , Mapping , Optional , Union
7
+ from typing import Any , Callable , Iterable , List , Mapping , MutableMapping , Optional , Union
8
8
9
9
from airbyte_cdk .models import AirbyteLogMessage , AirbyteMessage , Level , Type
10
10
from airbyte_cdk .sources .declarative .datetime .datetime_parser import DatetimeParser
@@ -70,10 +70,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
70
70
f"If step is defined, cursor_granularity should be as well and vice-versa. "
71
71
f"Right now, step is `{ self .step } ` and cursor_granularity is `{ self .cursor_granularity } `"
72
72
)
73
- if not isinstance (self .start_datetime , MinMaxDatetime ):
74
- self .start_datetime = MinMaxDatetime (self .start_datetime , parameters )
75
- if self .end_datetime and not isinstance (self .end_datetime , MinMaxDatetime ):
76
- self .end_datetime = MinMaxDatetime (self .end_datetime , parameters )
73
+ self ._start_datetime = MinMaxDatetime .create (self .start_datetime , parameters )
74
+ self ._end_datetime = None if not self .end_datetime else MinMaxDatetime .create (self .end_datetime , parameters )
77
75
78
76
self ._timezone = datetime .timezone .utc
79
77
self ._interpolation = JinjaInterpolation ()
@@ -84,23 +82,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
84
82
else datetime .timedelta .max
85
83
)
86
84
self ._cursor_granularity = self ._parse_timedelta (self .cursor_granularity )
87
- self .cursor_field = InterpolatedString .create (self .cursor_field , parameters = parameters )
88
- self .lookback_window = InterpolatedString .create (self .lookback_window , parameters = parameters )
89
- self .partition_field_start = InterpolatedString .create (self .partition_field_start or "start_time" , parameters = parameters )
90
- self .partition_field_end = InterpolatedString .create (self .partition_field_end or "end_time" , parameters = parameters )
85
+ self ._cursor_field = InterpolatedString .create (self .cursor_field , parameters = parameters )
86
+ self ._lookback_window = InterpolatedString .create (self .lookback_window , parameters = parameters ) if self . lookback_window else None
87
+ self ._partition_field_start = InterpolatedString .create (self .partition_field_start or "start_time" , parameters = parameters )
88
+ self ._partition_field_end = InterpolatedString .create (self .partition_field_end or "end_time" , parameters = parameters )
91
89
self ._parser = DatetimeParser ()
92
90
93
91
# If datetime format is not specified then start/end datetime should inherit it from the stream slicer
94
- if not self .start_datetime .datetime_format :
95
- self .start_datetime .datetime_format = self .datetime_format
96
- if self .end_datetime and not self .end_datetime .datetime_format :
97
- self .end_datetime .datetime_format = self .datetime_format
92
+ if not self ._start_datetime .datetime_format :
93
+ self ._start_datetime .datetime_format = self .datetime_format
94
+ if self ._end_datetime and not self ._end_datetime .datetime_format :
95
+ self ._end_datetime .datetime_format = self .datetime_format
98
96
99
97
if not self .cursor_datetime_formats :
100
98
self .cursor_datetime_formats = [self .datetime_format ]
101
99
102
100
def get_stream_state (self ) -> StreamState :
103
- return {self .cursor_field .eval (self .config ): self ._cursor } if self ._cursor else {}
101
+ return {self ._cursor_field .eval (self .config ): self ._cursor } if self ._cursor else {}
104
102
105
103
def set_initial_state (self , stream_state : StreamState ) -> None :
106
104
"""
@@ -109,17 +107,22 @@ def set_initial_state(self, stream_state: StreamState) -> None:
109
107
110
108
:param stream_state: The state of the stream as returned by get_stream_state
111
109
"""
112
- self ._cursor = stream_state .get (self .cursor_field .eval (self .config )) if stream_state else None
110
+ self ._cursor = stream_state .get (self ._cursor_field .eval (self .config )) if stream_state else None
113
111
114
112
def close_slice (self , stream_slice : StreamSlice , most_recent_record : Optional [Record ]) -> None :
115
- last_record_cursor_value = most_recent_record .get (self .cursor_field .eval (self .config )) if most_recent_record else None
116
- stream_slice_value_end = stream_slice .get (self .partition_field_end .eval (self .config ))
113
+ if stream_slice .partition :
114
+ raise ValueError (f"Stream slice { stream_slice } should not have a partition. Got { stream_slice .partition } ." )
115
+ last_record_cursor_value = most_recent_record .get (self ._cursor_field .eval (self .config )) if most_recent_record else None
116
+ stream_slice_value_end = stream_slice .get (self ._partition_field_end .eval (self .config ))
117
+ potential_cursor_values = [
118
+ cursor_value for cursor_value in [self ._cursor , last_record_cursor_value , stream_slice_value_end ] if cursor_value
119
+ ]
117
120
cursor_value_str_by_cursor_value_datetime = dict (
118
121
map (
119
122
# we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like
120
123
# 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z'
121
124
lambda datetime_str : (self .parse_date (datetime_str ), datetime_str ),
122
- filter ( lambda item : item , [ self . _cursor , last_record_cursor_value , stream_slice_value_end ]) ,
125
+ potential_cursor_values ,
123
126
)
124
127
)
125
128
self ._cursor = (
@@ -142,37 +145,43 @@ def stream_slices(self) -> Iterable[StreamSlice]:
142
145
return self ._partition_daterange (start_datetime , end_datetime , self ._step )
143
146
144
147
def _calculate_earliest_possible_value (self , end_datetime : datetime .datetime ) -> datetime .datetime :
145
- lookback_delta = self ._parse_timedelta (self .lookback_window .eval (self .config ) if self .lookback_window else "P0D" )
146
- earliest_possible_start_datetime = min (self .start_datetime .get_datetime (self .config ), end_datetime )
148
+ lookback_delta = self ._parse_timedelta (self ._lookback_window .eval (self .config ) if self .lookback_window else "P0D" )
149
+ earliest_possible_start_datetime = min (self ._start_datetime .get_datetime (self .config ), end_datetime )
147
150
cursor_datetime = self ._calculate_cursor_datetime_from_state (self .get_stream_state ())
148
151
return max (earliest_possible_start_datetime , cursor_datetime ) - lookback_delta
149
152
150
153
def _select_best_end_datetime (self ) -> datetime .datetime :
151
154
now = datetime .datetime .now (tz = self ._timezone )
152
- if not self .end_datetime :
155
+ if not self ._end_datetime :
153
156
return now
154
- return min (self .end_datetime .get_datetime (self .config ), now )
157
+ return min (self ._end_datetime .get_datetime (self .config ), now )
155
158
156
159
def _calculate_cursor_datetime_from_state (self , stream_state : Mapping [str , Any ]) -> datetime .datetime :
157
- if self .cursor_field .eval (self .config , stream_state = stream_state ) in stream_state :
158
- return self .parse_date (stream_state [self .cursor_field .eval (self .config )])
160
+ if self ._cursor_field .eval (self .config , stream_state = stream_state ) in stream_state :
161
+ return self .parse_date (stream_state [self ._cursor_field .eval (self .config )])
159
162
return datetime .datetime .min .replace (tzinfo = datetime .timezone .utc )
160
163
161
164
def _format_datetime (self , dt : datetime .datetime ) -> str :
162
165
return self ._parser .format (dt , self .datetime_format )
163
166
164
- def _partition_daterange (self , start : datetime .datetime , end : datetime .datetime , step : Union [datetime .timedelta , Duration ]):
165
- start_field = self .partition_field_start .eval (self .config )
166
- end_field = self .partition_field_end .eval (self .config )
167
+ def _partition_daterange (
168
+ self , start : datetime .datetime , end : datetime .datetime , step : Union [datetime .timedelta , Duration ]
169
+ ) -> List [StreamSlice ]:
170
+ start_field = self ._partition_field_start .eval (self .config )
171
+ end_field = self ._partition_field_end .eval (self .config )
167
172
dates = []
168
173
while start <= end :
169
174
next_start = self ._evaluate_next_start_date_safely (start , step )
170
175
end_date = self ._get_date (next_start - self ._cursor_granularity , end , min )
171
- dates .append ({start_field : self ._format_datetime (start ), end_field : self ._format_datetime (end_date )})
176
+ dates .append (
177
+ StreamSlice (
178
+ partition = {}, cursor_slice = {start_field : self ._format_datetime (start ), end_field : self ._format_datetime (end_date )}
179
+ )
180
+ )
172
181
start = next_start
173
182
return dates
174
183
175
- def _evaluate_next_start_date_safely (self , start , step ) :
184
+ def _evaluate_next_start_date_safely (self , start : datetime . datetime , step : datetime . timedelta ) -> datetime . datetime :
176
185
"""
177
186
Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date
178
187
This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code
@@ -183,7 +192,12 @@ def _evaluate_next_start_date_safely(self, start, step):
183
192
except OverflowError :
184
193
return datetime .datetime .max .replace (tzinfo = datetime .timezone .utc )
185
194
186
- def _get_date (self , cursor_value , default_date : datetime .datetime , comparator ) -> datetime .datetime :
195
+ def _get_date (
196
+ self ,
197
+ cursor_value : datetime .datetime ,
198
+ default_date : datetime .datetime ,
199
+ comparator : Callable [[datetime .datetime , datetime .datetime ], datetime .datetime ],
200
+ ) -> datetime .datetime :
187
201
cursor_date = cursor_value or default_date
188
202
return comparator (cursor_date , default_date )
189
203
@@ -196,7 +210,7 @@ def parse_date(self, date: str) -> datetime.datetime:
196
210
raise ValueError (f"No format in { self .cursor_datetime_formats } matching { date } " )
197
211
198
212
@classmethod
199
- def _parse_timedelta (cls , time_str ) -> Union [datetime .timedelta , Duration ]:
213
+ def _parse_timedelta (cls , time_str : Optional [ str ] ) -> Union [datetime .timedelta , Duration ]:
200
214
"""
201
215
:return Parses an ISO 8601 durations into datetime.timedelta or Duration objects.
202
216
"""
@@ -244,18 +258,20 @@ def request_kwargs(self) -> Mapping[str, Any]:
244
258
# Never update kwargs
245
259
return {}
246
260
247
- def _get_request_options (self , option_type : RequestOptionType , stream_slice : StreamSlice ):
248
- options = {}
261
+ def _get_request_options (self , option_type : RequestOptionType , stream_slice : Optional [StreamSlice ]) -> Mapping [str , Any ]:
262
+ options : MutableMapping [str , Any ] = {}
263
+ if not stream_slice :
264
+ return options
249
265
if self .start_time_option and self .start_time_option .inject_into == option_type :
250
- options [self .start_time_option .field_name .eval (config = self .config )] = stream_slice .get (
251
- self .partition_field_start .eval (self .config )
266
+ options [self .start_time_option .field_name .eval (config = self .config )] = stream_slice .get ( # type: ignore # field_name is always casted to an interpolated string
267
+ self ._partition_field_start .eval (self .config )
252
268
)
253
269
if self .end_time_option and self .end_time_option .inject_into == option_type :
254
- options [self .end_time_option .field_name .eval (config = self .config )] = stream_slice .get (self .partition_field_end .eval (self .config ))
270
+ options [self .end_time_option .field_name .eval (config = self .config )] = stream_slice .get (self ._partition_field_end .eval (self .config )) # type: ignore # field_name is always casted to an interpolated string
255
271
return options
256
272
257
273
def should_be_synced (self , record : Record ) -> bool :
258
- cursor_field = self .cursor_field .eval (self .config )
274
+ cursor_field = self ._cursor_field .eval (self .config )
259
275
record_cursor_value = record .get (cursor_field )
260
276
if not record_cursor_value :
261
277
self ._send_log (
@@ -278,7 +294,7 @@ def _send_log(self, level: Level, message: str) -> None:
278
294
)
279
295
280
296
def is_greater_than_or_equal (self , first : Record , second : Record ) -> bool :
281
- cursor_field = self .cursor_field .eval (self .config )
297
+ cursor_field = self ._cursor_field .eval (self .config )
282
298
first_cursor_value = first .get (cursor_field )
283
299
second_cursor_value = second .get (cursor_field )
284
300
if first_cursor_value and second_cursor_value :
0 commit comments