Skip to content

Commit

Permalink
chore(refactor): Remove Partition.close (#32)
Browse files (browse the repository at this point in the history)
Co-authored-by: Aaron ("AJ") Steers <[email protected]>
Co-authored-by: octavia-squidington-iii <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2024
1 parent: 39786d2; commit: ab7ab68
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def on_partition_complete_sentinel(

try:
if sentinel.is_successful:
partition.close()
stream = self._stream_name_to_instance[partition.stream_name()]
stream.cursor.close_partition(partition)
except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
Expand Down
11 changes: 0 additions & 11 deletions airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,13 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
self._slice = _slice
self._message_repository = message_repository
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False

def read(self) -> Iterable[Record]:
try:
Expand Down Expand Up @@ -289,13 +286,6 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
file = self._slice["files"][0]
return {"files": [file]}

def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True

def is_closed(self) -> bool:
return self._is_closed

def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
Expand Down Expand Up @@ -352,7 +342,6 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)
)
self._cursor.set_pending_partitions(pending_partitions)
Expand Down
15 changes: 0 additions & 15 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def create_from_stream(
else SyncMode.incremental,
[cursor_field] if cursor_field is not None else None,
state,
cursor,
),
name=stream.name,
namespace=stream.namespace,
Expand Down Expand Up @@ -259,7 +258,6 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: Cursor,
):
"""
:param stream: The stream to delegate to
Expand All @@ -272,8 +270,6 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False

def read(self) -> Iterable[Record]:
"""
Expand Down Expand Up @@ -323,13 +319,6 @@ def __hash__(self) -> int:
def stream_name(self) -> str:
return self._stream.name

def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True

def is_closed(self) -> bool:
return self._is_closed

def __repr__(self) -> str:
return f"StreamPartition({self._stream.name}, {self._slice})"

Expand All @@ -349,7 +338,6 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: Cursor,
):
"""
:param stream: The stream to delegate to
Expand All @@ -360,7 +348,6 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor

def generate(self) -> Iterable[Partition]:
for s in self._stream.stream_slices(
Expand All @@ -373,7 +360,6 @@ def generate(self) -> Iterable[Partition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)


Expand Down Expand Up @@ -451,7 +437,6 @@ def generate(self) -> Iterable[Partition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)


Expand Down
15 changes: 0 additions & 15 deletions airbyte_cdk/sources/streams/concurrent/partitions/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,6 @@ def stream_name(self) -> str:
"""
pass

@abstractmethod
def close(self) -> None:
"""
Closes the partition.
"""
pass

@abstractmethod
def is_closed(self) -> bool:
"""
Returns whether the partition is closed.
:return:
"""
pass

@abstractmethod
def __hash__(self) -> int:
"""
Expand Down
15 changes: 12 additions & 3 deletions unit_tests/sources/file_based/stream/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ def test_file_based_stream_partition(transformer, expected_records):
cursor_field = None
state = None
partition = FileBasedStreamPartition(
stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR
stream,
_slice,
message_repository,
sync_mode,
cursor_field,
state,
)

a_log_message = AirbyteMessage(
Expand Down Expand Up @@ -168,7 +173,6 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
_ANY_CURSOR,
)

stream.read_records.side_effect = Exception()
Expand Down Expand Up @@ -204,7 +208,12 @@ def test_file_based_stream_partition_hash(_slice, expected_hash):
stream = Mock()
stream.name = "stream"
partition = FileBasedStreamPartition(
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream,
_slice,
Mock(),
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
)

_hash = partition.__hash__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def test_add_file(
SyncMode.full_refresh,
FileBasedConcurrentCursor.CURSOR_FIELD,
initial_state,
cursor,
)
for uri, timestamp in pending_files
]
Expand Down
9 changes: 3 additions & 6 deletions unit_tests/sources/streams/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_stream_partition_generator(sync_mode):
stream.stream_slices.return_value = stream_slices

partition_generator = StreamPartitionGenerator(
stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE
)

partitions = list(partition_generator.generate())
Expand Down Expand Up @@ -115,9 +115,7 @@ def test_stream_partition(transformer, expected_records):
sync_mode = SyncMode.full_refresh
cursor_field = None
state = None
partition = StreamPartition(
stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR
)
partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state)

a_log_message = AirbyteMessage(
type=MessageType.LOG,
Expand Down Expand Up @@ -162,7 +160,6 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
_ANY_CURSOR,
)

stream.read_records.side_effect = Exception()
Expand All @@ -188,7 +185,7 @@ def test_stream_partition_hash(_slice, expected_hash):
stream = Mock()
stream.name = "stream"
partition = StreamPartition(
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE
)

_hash = partition.__hash__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel
]
assert messages == expected_messages

partition.close.assert_called_once()
self._stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done(
Expand Down Expand Up @@ -298,14 +298,14 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre
)
]
assert messages == expected_messages
self._a_closed_partition.close.assert_called_once()
self._another_stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete(
self,
) -> None:
self._a_closed_partition.stream_name.return_value = self._stream.name
self._a_closed_partition.close.side_effect = ValueError
self._stream.cursor.close_partition.side_effect = ValueError

handler = ConcurrentReadProcessor(
[self._stream],
Expand Down Expand Up @@ -375,7 +375,7 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s

expected_messages = []
assert messages == expected_messages
partition.close.assert_called_once()
self._stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_on_record_no_status_message_no_repository_messge(self):
Expand Down Expand Up @@ -733,7 +733,7 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s
)
)

assert self._an_open_partition.close.call_count == 0
assert self._stream.cursor.close_partition.call_count == 0

def test_is_done_is_false_if_there_are_any_instances_to_read_from(self):
stream_instances_to_read_from = [self._stream]
Expand Down

0 comments on commit ab7ab68

Please sign in to comment.