Skip to content

Commit

Permalink
[SDK] Ability to cancel running partition call. (#1211)
Browse files Browse the repository at this point in the history
* Ability to cancel running partition call.

* Add to __init__.py

* Add test; close explicitly.
  • Loading branch information
alexaryn authored Mar 5, 2025
1 parent 2ccfc58 commit 7e472da
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/aryn-sdk/aryn_sdk/partition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
tables_to_pandas,
table_elem_to_dataframe,
convert_image_element,
BoolFlag,
PartitionError,
PartitionTaskError,
PartitionTaskNotFoundError,
Expand All @@ -19,6 +20,7 @@
"tables_to_pandas",
"draw_with_boxes",
"convert_image_element",
"BoolFlag",
"PartitionError",
"PartitionTaskError",
"PartitionTaskNotFoundError",
Expand Down
28 changes: 28 additions & 0 deletions lib/aryn-sdk/aryn_sdk/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def __init__(self, message: str, status_code: int) -> None:
self.status_code = status_code


class BoolFlag:
    """
    A boxed boolean that can be mutated and passed around by reference.

    A plain ``bool`` is immutable, so a callee cannot observe a change the
    caller makes after the call has started. Wrapping the value in an object
    lets both sides share one mutable cell: one side calls ``set`` and the
    other polls ``get``.

    Attributes:
        val: the current value of the flag.
    """

    # A tuple (rather than a bare string) is the conventional spelling and
    # stays correct if more attributes are ever added.
    __slots__ = ("val",)

    def __init__(self, val: bool) -> None:
        """Create the flag holding the initial value ``val``."""
        self.val = val

    def set(self, val: bool) -> None:
        """Replace the stored value with ``val``."""
        self.val = val

    def get(self) -> bool:
        """Return the stored value."""
        return self.val

    def __repr__(self) -> str:
        """Return a debugging representation such as ``BoolFlag(val=True)``."""
        return f"{type(self).__name__}(val={self.val!r})"


def partition_file(
file: Union[BinaryIO, str, PathLike],
*,
Expand All @@ -64,6 +81,7 @@ def partition_file(
output_format: Optional[str] = None,
output_label_options: Optional[dict[str, Any]] = None,
trace_id: Optional[str] = None,
cancel_flag: Optional[BoolFlag] = None,
) -> dict:
"""
Sends file to Aryn DocParse and returns a dict of its document structure and text
Expand Down Expand Up @@ -136,6 +154,7 @@ def partition_file(
}
default: None (no element is promoted to "Title")
trace_id: for internal use
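cancel_flag: optional BoolFlag that lets the caller interrupt partitioning from the outside. It is checked while the response is being streamed; once it reads True, the response is closed and reading stops early.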
Returns:
Expand Down Expand Up @@ -177,6 +196,7 @@ def partition_file(
output_format=output_format,
output_label_options=output_label_options,
trace_id=trace_id,
cancel_flag=cancel_flag,
)


Expand All @@ -202,6 +222,7 @@ def _partition_file_wrapper(
output_label_options: Optional[dict[str, Any]] = None,
webhook_url: Optional[str] = None,
trace_id: Optional[str] = None,
cancel_flag: Optional[BoolFlag] = None,
):
"""Do not call this function directly. Use partition_file or partition_file_async_submit instead."""

Expand Down Expand Up @@ -231,6 +252,7 @@ def _partition_file_wrapper(
output_format=output_format,
output_label_options=output_label_options,
trace_id=trace_id,
cancel_flag=cancel_flag,
webhook_url=webhook_url,
)
finally:
Expand Down Expand Up @@ -259,6 +281,7 @@ def _partition_file_inner(
output_format: Optional[str] = None,
output_label_options: Optional[dict[str, Any]] = None,
trace_id: Optional[str] = None,
cancel_flag: Optional[BoolFlag] = None,
webhook_url: Optional[str] = None,
):
"""Do not call this function directly. Use partition_file or partition_file_async_submit instead."""
Expand Down Expand Up @@ -303,6 +326,11 @@ def _partition_file_inner(
partial_line = []
in_bulk = False
for part in resp.iter_content(None):
# A big doc could take a while; we may be asked to bail out early
if cancel_flag and cancel_flag.get():
resp.close()
break

if not part:
continue

Expand Down
7 changes: 7 additions & 0 deletions lib/aryn-sdk/aryn_sdk/test/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
partition_file_async_list,
PartitionError,
PartitionTaskNotFoundError,
BoolFlag,
convert_image_element,
tables_to_pandas,
)
Expand Down Expand Up @@ -157,6 +158,12 @@ def test_partition_file_auto_rotation():
assert actual["elements"] == expected["elements"]


def test_partition_file_cancel():
    """Partitioning with the cancel flag already raised is expected to raise."""
    already_cancelled = BoolFlag(True)
    pdf_path = RESOURCE_DIR / "pdfs" / "SPsort.pdf"

    # NOTE(review): Exception is very broad, so any failure (missing file,
    # auth, network) would also satisfy this test -- consider narrowing it
    # once the exception raised on cancellation is confirmed.
    with pytest.raises(Exception):
        partition_file(pdf_path, cancel_flag=already_cancelled)


def test_data_to_pandas():
with open(RESOURCE_DIR / "json" / "3m_output_ocr_table.json", "r") as f:
data = json.load(f)
Expand Down

0 comments on commit 7e472da

Please sign in to comment.