Skip to content

Commit 1869fa5

Browse files
authored
feat(concurrent perpartition cursor): Refactor ConcurrentPerPartitionCursor (#331)
1 parent 263343a commit 1869fa5

File tree

2 files changed

+114
-33
lines changed

2 files changed

+114
-33
lines changed

Diff for: airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

+56-25
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import logging
77
import threading
8+
import time
89
from collections import OrderedDict
910
from copy import deepcopy
1011
from datetime import timedelta
@@ -58,7 +59,8 @@ class ConcurrentPerPartitionCursor(Cursor):
5859
CurrentPerPartitionCursor expects the state of the ConcurrentCursor to follow the format {cursor_field: cursor_value}.
5960
"""
6061

61-
DEFAULT_MAX_PARTITIONS_NUMBER = 10000
62+
DEFAULT_MAX_PARTITIONS_NUMBER = 25_000
63+
SWITCH_TO_GLOBAL_LIMIT = 10_000
6264
_NO_STATE: Mapping[str, Any] = {}
6365
_NO_CURSOR_STATE: Mapping[str, Any] = {}
6466
_GLOBAL_STATE_KEY = "state"
@@ -99,9 +101,11 @@ def __init__(
99101
self._new_global_cursor: Optional[StreamState] = None
100102
self._lookback_window: int = 0
101103
self._parent_state: Optional[StreamState] = None
102-
self._over_limit: int = 0
104+
self._number_of_partitions: int = 0
103105
self._use_global_cursor: bool = False
104106
self._partition_serializer = PerPartitionKeySerializer()
107+
# Track the last time a state message was emitted
108+
self._last_emission_time: float = 0.0
105109

106110
self._set_initial_state(stream_state)
107111

@@ -141,21 +145,16 @@ def close_partition(self, partition: Partition) -> None:
141145
raise ValueError("stream_slice cannot be None")
142146

143147
partition_key = self._to_partition_key(stream_slice.partition)
144-
self._cursor_per_partition[partition_key].close_partition(partition=partition)
145148
with self._lock:
146149
self._semaphore_per_partition[partition_key].acquire()
147-
cursor = self._cursor_per_partition[partition_key]
148-
if (
149-
partition_key in self._finished_partitions
150-
and self._semaphore_per_partition[partition_key]._value == 0
151-
):
150+
if not self._use_global_cursor:
151+
self._cursor_per_partition[partition_key].close_partition(partition=partition)
152+
cursor = self._cursor_per_partition[partition_key]
152153
if (
153-
self._new_global_cursor is None
154-
or self._new_global_cursor[self.cursor_field.cursor_field_key]
155-
< cursor.state[self.cursor_field.cursor_field_key]
154+
partition_key in self._finished_partitions
155+
and self._semaphore_per_partition[partition_key]._value == 0
156156
):
157-
self._new_global_cursor = copy.deepcopy(cursor.state)
158-
if not self._use_global_cursor:
157+
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
159158
self._emit_state_message()
160159

161160
def ensure_at_least_one_state_emitted(self) -> None:
@@ -169,9 +168,23 @@ def ensure_at_least_one_state_emitted(self) -> None:
169168
self._global_cursor = self._new_global_cursor
170169
self._lookback_window = self._timer.finish()
171170
self._parent_state = self._partition_router.get_stream_state()
172-
self._emit_state_message()
171+
self._emit_state_message(throttle=False)
173172

174-
def _emit_state_message(self) -> None:
173+
def _throttle_state_message(self) -> Optional[float]:
174+
"""
175+
Throttles the state message emission to once every 60 seconds.
176+
"""
177+
current_time = time.time()
178+
if current_time - self._last_emission_time <= 60:
179+
return None
180+
return current_time
181+
182+
def _emit_state_message(self, throttle: bool = True) -> None:
183+
if throttle:
184+
current_time = self._throttle_state_message()
185+
if current_time is None:
186+
return
187+
self._last_emission_time = current_time
175188
self._connector_state_manager.update_state_for_stream(
176189
self._stream_name,
177190
self._stream_namespace,
@@ -202,6 +215,7 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St
202215
self._lookback_window if self._global_cursor else 0,
203216
)
204217
with self._lock:
218+
self._number_of_partitions += 1
205219
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
206220
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
207221
threading.Semaphore(0)
@@ -232,9 +246,15 @@ def _ensure_partition_limit(self) -> None:
232246
- Logs a warning each time a partition is removed, indicating whether it was finished
233247
or removed due to being the oldest.
234248
"""
249+
if not self._use_global_cursor and self.limit_reached():
250+
logger.info(
251+
f"Exceeded the 'SWITCH_TO_GLOBAL_LIMIT' of {self.SWITCH_TO_GLOBAL_LIMIT}. "
252+
f"Switching to global cursor for {self._stream_name}."
253+
)
254+
self._use_global_cursor = True
255+
235256
with self._lock:
236257
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
237-
self._over_limit += 1
238258
# Try removing finished partitions first
239259
for partition_key in list(self._cursor_per_partition.keys()):
240260
if (
@@ -245,7 +265,7 @@ def _ensure_partition_limit(self) -> None:
245265
partition_key
246266
) # Remove the oldest partition
247267
logger.warning(
248-
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._over_limit}."
268+
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
249269
)
250270
break
251271
else:
@@ -254,7 +274,7 @@ def _ensure_partition_limit(self) -> None:
254274
1
255275
] # Remove the oldest partition
256276
logger.warning(
257-
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
277+
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
258278
)
259279

260280
def _set_initial_state(self, stream_state: StreamState) -> None:
@@ -314,6 +334,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
314334
self._lookback_window = int(stream_state.get("lookback_window", 0))
315335

316336
for state in stream_state.get(self._PERPARTITION_STATE_KEY, []):
337+
self._number_of_partitions += 1
317338
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
318339
self._create_cursor(state["cursor"])
319340
)
@@ -354,16 +375,26 @@ def _set_global_state(self, stream_state: Mapping[str, Any]) -> None:
354375
self._new_global_cursor = deepcopy(fixed_global_state)
355376

356377
def observe(self, record: Record) -> None:
357-
if not self._use_global_cursor and self.limit_reached():
358-
self._use_global_cursor = True
359-
360378
if not record.associated_slice:
361379
raise ValueError(
362380
"Invalid state as stream slices that are emitted should refer to an existing cursor"
363381
)
364-
self._cursor_per_partition[
365-
self._to_partition_key(record.associated_slice.partition)
366-
].observe(record)
382+
383+
record_cursor = self._connector_state_converter.output_format(
384+
self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
385+
)
386+
self._update_global_cursor(record_cursor)
387+
if not self._use_global_cursor:
388+
self._cursor_per_partition[
389+
self._to_partition_key(record.associated_slice.partition)
390+
].observe(record)
391+
392+
def _update_global_cursor(self, value: Any) -> None:
393+
if (
394+
self._new_global_cursor is None
395+
or self._new_global_cursor[self.cursor_field.cursor_field_key] < value
396+
):
397+
self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}
367398

368399
def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
369400
return self._partition_serializer.to_partition_key(partition)
@@ -397,4 +428,4 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
397428
return cursor
398429

399430
def limit_reached(self) -> bool:
400-
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER
431+
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT

Diff for: unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

+58-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from copy import deepcopy
44
from datetime import datetime, timedelta
55
from typing import Any, List, Mapping, MutableMapping, Optional, Union
6+
from unittest.mock import MagicMock, patch
67
from urllib.parse import unquote
78

89
import pytest
@@ -18,6 +19,7 @@
1819
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
1920
ConcurrentDeclarativeSource,
2021
)
22+
from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
2123
from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder
2224
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read
2325

@@ -1181,14 +1183,18 @@ def test_incremental_parent_state(
11811183
initial_state,
11821184
expected_state,
11831185
):
1184-
run_incremental_parent_state_test(
1185-
manifest,
1186-
mock_requests,
1187-
expected_records,
1188-
num_intermediate_states,
1189-
initial_state,
1190-
[expected_state],
1191-
)
1186+
# Patch `_throttle_state_message` so it always returns a float (indicating "no throttle")
1187+
with patch.object(
1188+
ConcurrentPerPartitionCursor, "_throttle_state_message", return_value=9999999.0
1189+
):
1190+
run_incremental_parent_state_test(
1191+
manifest,
1192+
mock_requests,
1193+
expected_records,
1194+
num_intermediate_states,
1195+
initial_state,
1196+
[expected_state],
1197+
)
11921198

11931199

11941200
STATE_MIGRATION_EXPECTED_STATE = {
@@ -2967,3 +2973,47 @@ def test_incremental_substream_request_options_provider(
29672973
expected_records,
29682974
expected_state,
29692975
)
2976+
2977+
2978+
def test_state_throttling(mocker):
2979+
"""
2980+
Verifies that _emit_state_message does not emit a new state if less than 60s
2981+
have passed since last emission, and does emit once 60s or more have passed.
2982+
"""
2983+
cursor = ConcurrentPerPartitionCursor(
2984+
cursor_factory=MagicMock(),
2985+
partition_router=MagicMock(),
2986+
stream_name="test_stream",
2987+
stream_namespace=None,
2988+
stream_state={},
2989+
message_repository=MagicMock(),
2990+
connector_state_manager=MagicMock(),
2991+
connector_state_converter=MagicMock(),
2992+
cursor_field=MagicMock(),
2993+
)
2994+
2995+
mock_connector_manager = cursor._connector_state_manager
2996+
mock_repo = cursor._message_repository
2997+
2998+
# Set the last emission time to "0" so we can control offset from that
2999+
cursor._last_emission_time = 0
3000+
3001+
mock_time = mocker.patch("time.time")
3002+
3003+
# First attempt: only 10 seconds passed => NO emission
3004+
mock_time.return_value = 10
3005+
cursor._emit_state_message()
3006+
mock_connector_manager.update_state_for_stream.assert_not_called()
3007+
mock_repo.emit_message.assert_not_called()
3008+
3009+
# Second attempt: 30 seconds passed => still NO emission
3010+
mock_time.return_value = 30
3011+
cursor._emit_state_message()
3012+
mock_connector_manager.update_state_for_stream.assert_not_called()
3013+
mock_repo.emit_message.assert_not_called()
3014+
3015+
# Advance time: 70 seconds => exceed 60s => MUST emit
3016+
mock_time.return_value = 70
3017+
cursor._emit_state_message()
3018+
mock_connector_manager.update_state_for_stream.assert_called_once()
3019+
mock_repo.emit_message.assert_called_once()

0 commit comments

Comments
 (0)