@@ -95,6 +95,10 @@ def __init__(
95
95
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
96
96
self ._cursor_per_partition : OrderedDict [str , ConcurrentCursor ] = OrderedDict ()
97
97
self ._semaphore_per_partition : OrderedDict [str , threading .Semaphore ] = OrderedDict ()
98
+
99
+ # Parent-state tracking: store each partition’s parent state in creation order
100
+ self ._partition_parent_state_map : OrderedDict [str , Mapping [str , Any ]] = OrderedDict ()
101
+
98
102
self ._finished_partitions : set [str ] = set ()
99
103
self ._lock = threading .Lock ()
100
104
self ._timer = Timer ()
@@ -155,11 +159,62 @@ def close_partition(self, partition: Partition) -> None:
155
159
and self ._semaphore_per_partition [partition_key ]._value == 0
156
160
):
157
161
self ._update_global_cursor (cursor .state [self .cursor_field .cursor_field_key ])
158
- self ._emit_state_message ()
162
+
163
+ self ._check_and_update_parent_state ()
164
+
165
+ self ._emit_state_message ()
166
+
167
+ def _check_and_update_parent_state (self ) -> None :
168
+ """
169
+ Pop the leftmost partition state from _partition_parent_state_map only if
170
+ *all partitions* up to (and including) that partition key in _semaphore_per_partition
171
+ are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
172
+ Additionally, delete finished semaphores with a value of 0 to free up memory,
173
+ as they are only needed to track errors and completion status.
174
+ """
175
+ last_closed_state = None
176
+
177
+ while self ._partition_parent_state_map :
178
+ # Look at the earliest partition key in creation order
179
+ earliest_key = next (iter (self ._partition_parent_state_map ))
180
+
181
+ # Verify ALL partitions from the left up to earliest_key are finished
182
+ all_left_finished = True
183
+ for p_key , sem in list (
184
+ self ._semaphore_per_partition .items ()
185
+ ): # Use list to allow modification during iteration
186
+ # If any earlier partition is still not finished, we must stop
187
+ if p_key not in self ._finished_partitions or sem ._value != 0 :
188
+ all_left_finished = False
189
+ break
190
+ # Once we've reached earliest_key in the semaphore order, we can stop checking
191
+ if p_key == earliest_key :
192
+ break
193
+
194
+ # If the partitions up to earliest_key are not all finished, break the while-loop
195
+ if not all_left_finished :
196
+ break
197
+
198
+ # Pop the leftmost entry from parent-state map
199
+ _ , closed_parent_state = self ._partition_parent_state_map .popitem (last = False )
200
+ last_closed_state = closed_parent_state
201
+
202
+ # Clean up finished semaphores with value 0 up to and including earliest_key
203
+ for p_key in list (self ._semaphore_per_partition .keys ()):
204
+ sem = self ._semaphore_per_partition [p_key ]
205
+ if p_key in self ._finished_partitions and sem ._value == 0 :
206
+ del self ._semaphore_per_partition [p_key ]
207
+ logger .debug (f"Deleted finished semaphore for partition { p_key } with value 0" )
208
+ if p_key == earliest_key :
209
+ break
210
+
211
+ # Update _parent_state if we popped at least one partition
212
+ if last_closed_state is not None :
213
+ self ._parent_state = last_closed_state
159
214
160
215
def ensure_at_least_one_state_emitted (self ) -> None :
161
216
"""
162
- The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
217
+ The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
163
218
called.
164
219
"""
165
220
if not any (
@@ -201,13 +256,19 @@ def stream_slices(self) -> Iterable[StreamSlice]:
201
256
202
257
slices = self ._partition_router .stream_slices ()
203
258
self ._timer .start ()
204
- for partition in slices :
205
- yield from self ._generate_slices_from_partition (partition )
259
+ for partition , last , parent_state in iterate_with_last_flag_and_state (
260
+ slices , self ._partition_router .get_stream_state
261
+ ):
262
+ yield from self ._generate_slices_from_partition (partition , parent_state )
206
263
207
- def _generate_slices_from_partition (self , partition : StreamSlice ) -> Iterable [StreamSlice ]:
264
+ def _generate_slices_from_partition (
265
+ self , partition : StreamSlice , parent_state : Mapping [str , Any ]
266
+ ) -> Iterable [StreamSlice ]:
208
267
# Ensure the maximum number of partitions is not exceeded
209
268
self ._ensure_partition_limit ()
210
269
270
+ partition_key = self ._to_partition_key (partition .partition )
271
+
211
272
cursor = self ._cursor_per_partition .get (self ._to_partition_key (partition .partition ))
212
273
if not cursor :
213
274
cursor = self ._create_cursor (
@@ -216,18 +277,26 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St
216
277
)
217
278
with self ._lock :
218
279
self ._number_of_partitions += 1
219
- self ._cursor_per_partition [self ._to_partition_key (partition .partition )] = cursor
220
- self ._semaphore_per_partition [self ._to_partition_key (partition .partition )] = (
221
- threading .Semaphore (0 )
222
- )
280
+ self ._cursor_per_partition [partition_key ] = cursor
281
+ self ._semaphore_per_partition [partition_key ] = threading .Semaphore (0 )
282
+
283
+ with self ._lock :
284
+ if (
285
+ len (self ._partition_parent_state_map ) == 0
286
+ or self ._partition_parent_state_map [
287
+ next (reversed (self ._partition_parent_state_map ))
288
+ ]
289
+ != parent_state
290
+ ):
291
+ self ._partition_parent_state_map [partition_key ] = deepcopy (parent_state )
223
292
224
293
for cursor_slice , is_last_slice , _ in iterate_with_last_flag_and_state (
225
294
cursor .stream_slices (),
226
295
lambda : None ,
227
296
):
228
- self ._semaphore_per_partition [self . _to_partition_key ( partition . partition ) ].release ()
297
+ self ._semaphore_per_partition [partition_key ].release ()
229
298
if is_last_slice :
230
- self ._finished_partitions .add (self . _to_partition_key ( partition . partition ) )
299
+ self ._finished_partitions .add (partition_key )
231
300
yield StreamSlice (
232
301
partition = partition , cursor_slice = cursor_slice , extra_fields = partition .extra_fields
233
302
)
@@ -257,9 +326,9 @@ def _ensure_partition_limit(self) -> None:
257
326
while len (self ._cursor_per_partition ) > self .DEFAULT_MAX_PARTITIONS_NUMBER - 1 :
258
327
# Try removing finished partitions first
259
328
for partition_key in list (self ._cursor_per_partition .keys ()):
260
- if (
261
- partition_key in self ._finished_partitions
262
- and self ._semaphore_per_partition [partition_key ]._value == 0
329
+ if partition_key in self . _finished_partitions and (
330
+ partition_key not in self ._semaphore_per_partition
331
+ or self ._semaphore_per_partition [partition_key ]._value == 0
263
332
):
264
333
oldest_partition = self ._cursor_per_partition .pop (
265
334
partition_key
@@ -338,9 +407,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
338
407
self ._cursor_per_partition [self ._to_partition_key (state ["partition" ])] = (
339
408
self ._create_cursor (state ["cursor" ])
340
409
)
341
- self ._semaphore_per_partition [self ._to_partition_key (state ["partition" ])] = (
342
- threading .Semaphore (0 )
343
- )
344
410
345
411
# set default state for missing partitions if it is per partition with fallback to global
346
412
if self ._GLOBAL_STATE_KEY in stream_state :
0 commit comments