fix(concurrent-perpartition-cursor): Fix memory issues #568

Merged: 9 commits, merged on Jun 12, 2025
Changes from 7 commits
airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
@@ -9,7 +9,7 @@
from collections import OrderedDict
from copy import deepcopy
from datetime import timedelta
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
_GLOBAL_STATE_KEY = "state"
_PERPARTITION_STATE_KEY = "states"
_IS_PARTITION_DUPLICATION_LOGGED = False
_KEY = 0
_VALUE = 1
_PARENT_STATE = 0
_GENERATION_SEQUENCE = 1

def __init__(
self,
@@ -99,19 +99,29 @@ def __init__(
self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()

# Parent-state tracking: store each partition’s parent state in creation order
self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = (
OrderedDict()
)
self._parent_state: Optional[StreamState] = None

# Tracks partitions whose last stream slice has been emitted
self._partitions_done_generating_stream_slices: set[str] = set()
# Used to track the indexes of partitions that are not closed yet
self._processing_partitions_indexes: List[int] = list()
self._generated_partitions_count: int = 0
# Dictionary to map partition keys to their index
self._partition_key_to_index: dict[str, int] = {}

self._finished_partitions: set[str] = set()
self._lock = threading.Lock()
self._timer = Timer()
self._new_global_cursor: Optional[StreamState] = None
self._lookback_window: int = 0
self._parent_state: Optional[StreamState] = None
self._new_global_cursor: Optional[StreamState] = None
self._number_of_partitions: int = 0
self._use_global_cursor: bool = use_global_cursor
self._partition_serializer = PerPartitionKeySerializer()

# Track the last time a state message was emitted
self._last_emission_time: float = 0.0
self._timer = Timer()

self._set_initial_state(stream_state)

@@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None:
self._cursor_per_partition[partition_key].close_partition(partition=partition)
cursor = self._cursor_per_partition[partition_key]
if (
partition_key in self._finished_partitions
partition_key in self._partitions_done_generating_stream_slices
and self._semaphore_per_partition[partition_key]._value == 0
):
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])

# Clean up the partition if it is fully processed
self._cleanup_if_done(partition_key)

self._check_and_update_parent_state()

self._emit_state_message()

def _check_and_update_parent_state(self) -> None:
"""
Pop the leftmost entry from _partition_parent_state_map and promote its
parent state to _parent_state, but only once every partition whose
generation sequence number is at or below that entry's sequence has
finished processing. This keeps the emitted parent state from running
ahead of any partition that is still in flight.
"""
last_closed_state = None

while self._partition_parent_state_map:
# Look at the earliest partition key in creation order
earliest_key = next(iter(self._partition_parent_state_map))

# Verify ALL partitions from the left up to earliest_key are finished
all_left_finished = True
for p_key, sem in list(
self._semaphore_per_partition.items()
): # Use list to allow modification during iteration
# If any earlier partition is still not finished, we must stop
if p_key not in self._finished_partitions or sem._value != 0:
all_left_finished = False
break
# Once we've reached earliest_key in the semaphore order, we can stop checking
if p_key == earliest_key:
break

# If the partitions up to earliest_key are not all finished, break the while-loop
if not all_left_finished:
break
earliest_key, (candidate_state, candidate_seq) = next(
iter(self._partition_parent_state_map.items())
)

# Pop the leftmost entry from parent-state map
_, closed_parent_state = self._partition_parent_state_map.popitem(last=False)
last_closed_state = closed_parent_state
# if any partition that started <= candidate_seq is still open, we must wait
if (
self._processing_partitions_indexes
and self._processing_partitions_indexes[0] <= candidate_seq
):
break

# Clean up finished semaphores with value 0 up to and including earliest_key
for p_key in list(self._semaphore_per_partition.keys()):
sem = self._semaphore_per_partition[p_key]
if p_key in self._finished_partitions and sem._value == 0:
del self._semaphore_per_partition[p_key]
logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0")
if p_key == earliest_key:
break
# safe to pop
self._partition_parent_state_map.popitem(last=False)
last_closed_state = candidate_state

# Update _parent_state if we popped at least one partition
if last_closed_state is not None:
self._parent_state = last_closed_state

@@ -289,29 +276,39 @@ def _generate_slices_from_partition(
if not self._IS_PARTITION_DUPLICATION_LOGGED:
logger.warning(f"Partition duplication detected for stream {self._stream_name}")
self._IS_PARTITION_DUPLICATION_LOGGED = True
return
else:
self._semaphore_per_partition[partition_key] = threading.Semaphore(0)

with self._lock:
seq = self._generated_partitions_count
self._generated_partitions_count += 1
self._processing_partitions_indexes.append(seq)
self._partition_key_to_index[partition_key] = seq

if (
len(self._partition_parent_state_map) == 0
or self._partition_parent_state_map[
next(reversed(self._partition_parent_state_map))
]
][self._PARENT_STATE]
!= parent_state
):
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq)

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[partition_key].release()
if is_last_slice:
self._finished_partitions.add(partition_key)
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)
try:
for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[partition_key].release()
if is_last_slice:
self._partitions_done_generating_stream_slices.add(partition_key)
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)
finally:
del cursor
del partition
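
The semaphore still counts in-flight slices: generation releases it once per emitted slice, and (per the unchanged close path, not shown in this hunk) closing a slice acquires it once, so a partition is fully done only when its last slice has been generated and the semaphore has drained back to zero. A toy illustration of that handshake, including the private `_value` peek the PR itself relies on:

    import threading

    sem = threading.Semaphore(0)
    done_generating = False

    # Producer side (slice generation): one release per emitted stream slice.
    for _ in range(3):  # pretend the cursor yields three slices
        sem.release()
    done_generating = True  # _partitions_done_generating_stream_slices.add(key)

    # Consumer side (partition close): one acquire per processed slice.
    for _ in range(3):
        sem.acquire()

    # Fully processed only when both conditions hold, as in close_partition().
    assert done_generating and sem._value == 0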

def _ensure_partition_limit(self) -> None:
"""
@@ -338,10 +335,7 @@ def _ensure_partition_limit(self) -> None:
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
# Try removing finished partitions first
for partition_key in list(self._cursor_per_partition.keys()):
if partition_key in self._finished_partitions and (
partition_key not in self._semaphore_per_partition
or self._semaphore_per_partition[partition_key]._value == 0
):
if partition_key not in self._partition_key_to_index:
oldest_partition = self._cursor_per_partition.pop(
partition_key
) # Remove the oldest partition
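
Eviction also gets simpler with the index map: any cursor whose key is no longer in `_partition_key_to_index` has been fully processed and cleaned up, so it can be dropped without inspecting semaphores. A condensed sketch of the policy; the fallback of evicting the oldest running partition once no finished one remains is assumed from the truncated code around this hunk:

    from collections import OrderedDict

    MAX_PARTITIONS = 10_000  # stands in for DEFAULT_MAX_PARTITIONS_NUMBER

    def ensure_partition_limit(
        cursors: "OrderedDict[str, object]", open_keys: set[str]
    ) -> None:
        while len(cursors) > MAX_PARTITIONS - 1:
            # Phase 1: prefer fully processed partitions (no longer tracked as open).
            for key in list(cursors):
                if key not in open_keys:
                    cursors.pop(key)
                    break
            else:
                # Phase 2 (assumed fallback): evict the oldest partition outright.
                cursors.popitem(last=False)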
@@ -474,6 +468,25 @@ def _update_global_cursor(self, value: Any) -> None:
):
self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}

def _cleanup_if_done(self, partition_key: str) -> None:
"""
Free the in-memory bookkeeping that belonged to a fully processed partition:
its semaphore, its done-generating-slices flag, and its generation-sequence index.
"""
if not (
partition_key in self._partitions_done_generating_stream_slices
and self._semaphore_per_partition[partition_key]._value == 0
):
return

self._semaphore_per_partition.pop(partition_key, None)
self._partitions_done_generating_stream_slices.discard(partition_key)

seq = self._partition_key_to_index.pop(partition_key)
self._processing_partitions_indexes.remove(seq)

logger.debug(f"Partition {partition_key} fully processed and cleaned up.")

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)

@@ -483,11 +496,10 @@ def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
def _create_cursor(
self, cursor_state: Any, runtime_lookback_window: int = 0
) -> ConcurrentCursor:
cursor = self._cursor_factory.create(
return self._cursor_factory.create(
stream_state=deepcopy(cursor_state),
runtime_lookback_window=timedelta(seconds=runtime_lookback_window),
)
return cursor

def should_be_synced(self, record: Record) -> bool:
return self._get_cursor(record).should_be_synced(record)
@@ -502,8 +514,7 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
)
cursor = self._cursor_per_partition[partition_key]
return cursor
return self._cursor_per_partition[partition_key]

def limit_reached(self) -> bool:
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT
4 changes: 2 additions & 2 deletions airbyte_cdk/sources/declarative/interpolation/jinja.py
@@ -145,15 +145,15 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
# It can be returned as is
return s

@cache
# @cache
def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]:
"""
Find undeclared variables (caching is disabled here to curb memory growth)
"""
ast = _ENVIRONMENT.parse(s) # type: ignore # parse is able to handle None
return meta.find_undeclared_variables(ast)

@cache
# @cache
def _compile(self, s: str) -> Template:
"""
We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader
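
Commenting out these two `@cache` decorators is plausibly the other half of the memory fix: `functools.cache` on an instance method keys every entry on `(self, s)` and holds strong references to both, so each distinct interpolated string, and every interpolation instance, stays resident for the life of the process. A minimal reproduction of that retention pattern, not code from this PR:

    from functools import cache

    class Interpolator:
        @cache  # cache key includes `self`: neither instances nor strings are freed
        def compile(self, s: str) -> str:
            return s.upper()  # stand-in for _ENVIRONMENT.from_string(s)

    interp = Interpolator()
    for i in range(100_000):
        interp.compile(f"{{{{ record[{i}] }}}}")  # one entry per unique template string
    # The unbounded cache now holds 100,000 entries; a bounded lru_cache or an
    # instance-level cache with a size cap would avoid this growth.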