feat(concurrent perpartition cursor): Refactor ConcurrentPerPartitionCursor #331

Merged: 22 commits, Feb 18, 2025
@@ -5,6 +5,7 @@
import copy
import logging
import threading
import time
from collections import OrderedDict
from copy import deepcopy
from datetime import timedelta
@@ -58,7 +59,8 @@ class ConcurrentPerPartitionCursor(Cursor):
CurrentPerPartitionCursor expects the state of the ConcurrentCursor to follow the format {cursor_field: cursor_value}.
"""

DEFAULT_MAX_PARTITIONS_NUMBER = 10000
DEFAULT_MAX_PARTITIONS_NUMBER = 25_000
SWITCH_TO_GLOBAL_LIMIT = 10_000
_NO_STATE: Mapping[str, Any] = {}
_NO_CURSOR_STATE: Mapping[str, Any] = {}
_GLOBAL_STATE_KEY = "state"
@@ -99,9 +101,11 @@ def __init__(
self._new_global_cursor: Optional[StreamState] = None
self._lookback_window: int = 0
self._parent_state: Optional[StreamState] = None
self._over_limit: int = 0
self._number_of_partitions: int = 0
self._use_global_cursor: bool = False
self._partition_serializer = PerPartitionKeySerializer()
# Track the last time a state message was emitted
self._last_emission_time: float = 0.0

self._set_initial_state(stream_state)

@@ -141,21 +145,16 @@ def close_partition(self, partition: Partition) -> None:
raise ValueError("stream_slice cannot be None")

partition_key = self._to_partition_key(stream_slice.partition)
self._cursor_per_partition[partition_key].close_partition(partition=partition)
with self._lock:
self._semaphore_per_partition[partition_key].acquire()
cursor = self._cursor_per_partition[partition_key]
if (
partition_key in self._finished_partitions
and self._semaphore_per_partition[partition_key]._value == 0
):
if not self._use_global_cursor:
self._cursor_per_partition[partition_key].close_partition(partition=partition)
cursor = self._cursor_per_partition[partition_key]
if (
self._new_global_cursor is None
or self._new_global_cursor[self.cursor_field.cursor_field_key]
< cursor.state[self.cursor_field.cursor_field_key]
partition_key in self._finished_partitions
and self._semaphore_per_partition[partition_key]._value == 0
):
self._new_global_cursor = copy.deepcopy(cursor.state)
if not self._use_global_cursor:
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
self._emit_state_message()

def ensure_at_least_one_state_emitted(self) -> None:
@@ -169,9 +168,23 @@ def ensure_at_least_one_state_emitted(self) -> None:
self._global_cursor = self._new_global_cursor
self._lookback_window = self._timer.finish()
self._parent_state = self._partition_router.get_stream_state()
self._emit_state_message()
self._emit_state_message(throttle=False)

def _emit_state_message(self) -> None:
def _throttle_state_message(self) -> Optional[float]:
"""
Throttles the state message emission to once every 60 seconds.
"""
current_time = time.time()
if current_time - self._last_emission_time <= 60:
return None
return current_time

def _emit_state_message(self, throttle: bool = True) -> None:
if throttle:
current_time = self._throttle_state_message()
if current_time is None:
return
self._last_emission_time = current_time
self._connector_state_manager.update_state_for_stream(
self._stream_name,
self._stream_namespace,
@@ -202,6 +215,7 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St
self._lookback_window if self._global_cursor else 0,
)
with self._lock:
self._number_of_partitions += 1
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
threading.Semaphore(0)
@@ -232,9 +246,15 @@ def _ensure_partition_limit(self) -> None:
- Logs a warning each time a partition is removed, indicating whether it was finished
or removed due to being the oldest.
"""
if not self._use_global_cursor and self.limit_reached():
logger.info(
f"Exceeded the 'SWITCH_TO_GLOBAL_LIMIT' of {self.SWITCH_TO_GLOBAL_LIMIT}. "
f"Switching to global cursor for {self._stream_name}."
)
self._use_global_cursor = True

with self._lock:
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
self._over_limit += 1
# Try removing finished partitions first
for partition_key in list(self._cursor_per_partition.keys()):
if (
@@ -245,7 +265,7 @@ def _ensure_partition_limit(self) -> None:
partition_key
) # Remove the oldest partition
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._over_limit}."
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
)
break
else:
@@ -254,7 +274,7 @@
1
] # Remove the oldest partition
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
)

def _set_initial_state(self, stream_state: StreamState) -> None:
@@ -314,6 +334,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
self._lookback_window = int(stream_state.get("lookback_window", 0))

for state in stream_state.get(self._PERPARTITION_STATE_KEY, []):
self._number_of_partitions += 1
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
self._create_cursor(state["cursor"])
)
@@ -354,16 +375,26 @@ def _set_global_state(self, stream_state: Mapping[str, Any]) -> None:
self._new_global_cursor = deepcopy(fixed_global_state)

def observe(self, record: Record) -> None:
if not self._use_global_cursor and self.limit_reached():
self._use_global_cursor = True

if not record.associated_slice:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
)
self._cursor_per_partition[
self._to_partition_key(record.associated_slice.partition)
].observe(record)

record_cursor = self._connector_state_converter.output_format(
self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
)
self._update_global_cursor(record_cursor)
if not self._use_global_cursor:
self._cursor_per_partition[
self._to_partition_key(record.associated_slice.partition)
].observe(record)

def _update_global_cursor(self, value: Any) -> None:
if (
self._new_global_cursor is None
or self._new_global_cursor[self.cursor_field.cursor_field_key] < value
):
self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)
@@ -397,4 +428,4 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
return cursor

def limit_reached(self) -> bool:
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT
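
The change above adds two mechanisms to the cursor: state messages are throttled to at most one every 60 seconds (with a bypass for the final message), and the cursor switches to a single global cursor once more than SWITCH_TO_GLOBAL_LIMIT partitions have been generated. A minimal standalone sketch of the throttling pattern, using hypothetical names (ThrottledEmitter, maybe_emit) that are not part of the CDK API:

import time
from typing import Callable, Optional

class ThrottledEmitter:
    # Sketch only: emit at most once per interval unless throttling is disabled.
    EMISSION_INTERVAL_SECONDS = 60

    def __init__(self, emit: Callable[[], None]) -> None:
        self._emit = emit
        self._last_emission_time: float = 0.0

    def _throttle(self) -> Optional[float]:
        # Mirrors _throttle_state_message: return the current time when enough
        # time has elapsed, otherwise None to signal "skip this emission".
        now = time.time()
        if now - self._last_emission_time <= self.EMISSION_INTERVAL_SECONDS:
            return None
        return now

    def maybe_emit(self, throttle: bool = True) -> None:
        if throttle:
            now = self._throttle()
            if now is None:
                return
            self._last_emission_time = now
        self._emit()

In the diff, ensure_at_least_one_state_emitted calls _emit_state_message(throttle=False), the equivalent of maybe_emit(throttle=False) here, so the closing state message is never suppressed by the throttle.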
@@ -3,6 +3,7 @@
from copy import deepcopy
from datetime import datetime, timedelta
from typing import Any, List, Mapping, MutableMapping, Optional, Union
from unittest.mock import MagicMock, patch
from urllib.parse import unquote

import pytest
@@ -18,6 +19,7 @@
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read

@@ -1181,14 +1183,18 @@ def test_incremental_parent_state(
initial_state,
expected_state,
):
run_incremental_parent_state_test(
manifest,
mock_requests,
expected_records,
num_intermediate_states,
initial_state,
[expected_state],
)
# Patch `_throttle_state_message` so it always returns a float (indicating "no throttle")
with patch.object(
ConcurrentPerPartitionCursor, "_throttle_state_message", return_value=9999999.0
):
run_incremental_parent_state_test(
manifest,
mock_requests,
expected_records,
num_intermediate_states,
initial_state,
[expected_state],
)


STATE_MIGRATION_EXPECTED_STATE = {
@@ -2967,3 +2973,47 @@ def test_incremental_substream_request_options_provider(
expected_records,
expected_state,
)


def test_state_throttling(mocker):
"""
Verifies that _emit_state_message does not emit a new state if less than 60s
have passed since last emission, and does emit once 60s or more have passed.
"""
cursor = ConcurrentPerPartitionCursor(
cursor_factory=MagicMock(),
partition_router=MagicMock(),
stream_name="test_stream",
stream_namespace=None,
stream_state={},
message_repository=MagicMock(),
connector_state_manager=MagicMock(),
connector_state_converter=MagicMock(),
cursor_field=MagicMock(),
)

mock_connector_manager = cursor._connector_state_manager
mock_repo = cursor._message_repository

# Set the last emission time to "0" so we can control offset from that
cursor._last_emission_time = 0

mock_time = mocker.patch("time.time")

# First attempt: only 10 seconds passed => NO emission
mock_time.return_value = 10
cursor._emit_state_message()
mock_connector_manager.update_state_for_stream.assert_not_called()
mock_repo.emit_message.assert_not_called()

# Second attempt: 30 seconds passed => still NO emission
mock_time.return_value = 30
cursor._emit_state_message()
mock_connector_manager.update_state_for_stream.assert_not_called()
mock_repo.emit_message.assert_not_called()

# Advance time: 70 seconds => exceed 60s => MUST emit
mock_time.return_value = 70
cursor._emit_state_message()
mock_connector_manager.update_state_for_stream.assert_called_once()
mock_repo.emit_message.assert_called_once()
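
As a rough illustration of how the two limits in the first file interact (an assumed reading of the diff, with a hypothetical helper name): the cursor flips to global-cursor mode once more than SWITCH_TO_GLOBAL_LIMIT (10,000) partitions have been counted, evicts the oldest stored cursors once DEFAULT_MAX_PARTITIONS_NUMBER (25,000) would be exceeded, and reports "over limit" as the number of partitions seen beyond that cap.

SWITCH_TO_GLOBAL_LIMIT = 10_000
DEFAULT_MAX_PARTITIONS_NUMBER = 25_000

def describe_partition_pressure(partitions_seen: int, cursors_stored: int) -> str:
    # Illustrative only; not CDK code.
    mode = "global" if partitions_seen > SWITCH_TO_GLOBAL_LIMIT else "per-partition"
    evicting = cursors_stored > DEFAULT_MAX_PARTITIONS_NUMBER - 1
    over_limit = max(0, partitions_seen - DEFAULT_MAX_PARTITIONS_NUMBER)
    return f"mode={mode}, evicting_oldest={evicting}, over_limit={over_limit}"

# Example: 30,000 partitions seen while 25,000 cursors are retained
# -> mode=global, evicting_oldest=True, over_limit=5000
print(describe_partition_pressure(30_000, 25_000))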