diff --git a/Cargo.lock b/Cargo.lock index 9497218c9..ca99964e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -837,6 +837,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "cstr" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68523903c8ae5aacfa32a0d9ae60cadeb764e1da14ee0d26b1f3089f13a54636" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "csv" version = "1.3.1" @@ -1544,6 +1554,7 @@ version = "49.0.0" dependencies = [ "arrow", "async-trait", + "cstr", "datafusion", "datafusion-ffi", "datafusion-proto", diff --git a/Cargo.toml b/Cargo.toml index e51f4ddea..0be83a31d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,17 +26,34 @@ readme = "README.md" license = "Apache-2.0" edition = "2021" rust-version = "1.78" -include = ["/src", "/datafusion", "/LICENSE.txt", "build.rs", "pyproject.toml", "Cargo.toml", "Cargo.lock"] +include = [ + "/src", + "/datafusion", + "/LICENSE.txt", + "build.rs", + "pyproject.toml", + "Cargo.toml", + "Cargo.lock", +] [features] default = ["mimalloc"] -protoc = [ "datafusion-substrait/protoc" ] +protoc = ["datafusion-substrait/protoc"] substrait = ["dep:datafusion-substrait"] [dependencies] -tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } -pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } -pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} +tokio = { version = "1.45", features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", +] } +pyo3 = { version = "0.24", features = [ + "extension-module", + "abi3", + "abi3-py39", +] } +pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] } pyo3-log = "0.12.4" arrow = { version = "55.1.0", features = ["pyarrow"] } datafusion = { version = "49.0.2", features = ["avro", "unicode_expressions"] } @@ -45,15 +62,23 @@ datafusion-proto = { version = "49.0.2" } datafusion-ffi = { version = "49.0.2" } prost = "0.13.1" # keep in line with `datafusion-substrait` uuid = { version = "1.18", features = ["v4"] } -mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } +mimalloc = { version = "0.1", optional = true, default-features = false, features = [ + "local_dynamic_tls", +] } async-trait = "0.1.89" futures = "0.3" -object_store = { version = "0.12.3", features = ["aws", "gcp", "azure", "http"] } +cstr = "0.2" +object_store = { version = "0.12.3", features = [ + "aws", + "gcp", + "azure", + "http", +] } url = "2" log = "0.4.27" [build-dependencies] -prost-types = "0.13.1" # keep in line with `datafusion-substrait` +prost-types = "0.13.1" # keep in line with `datafusion-substrait` pyo3-build-config = "0.24" [lib] diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index f69485af7..2b573ea4e 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -145,10 +145,118 @@ To materialize the results of your DataFrame operations: # Display results df.show() # Print tabular format to console - + # Count rows count = df.count() +Zero-copy streaming to Arrow-based Python libraries +--------------------------------------------------- + +DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling +zero-copy, lazy streaming into Arrow-based Python libraries. With the streaming +protocol, batches are produced on demand so you can process arbitrarily large +results without out-of-memory errors. + +.. 
note:: + + The protocol is implementation-agnostic and works with any Python library + that understands the Arrow C streaming interface (for example, PyArrow + or other Arrow-compatible implementations). The sections below provide a + short PyArrow-specific example and general guidance for other + implementations. + +PyArrow +------- + +.. code-block:: python + + import pyarrow as pa + + # Create a PyArrow RecordBatchReader without materializing all batches + reader = pa.RecordBatchReader.from_stream(df) + for batch in reader: + ... # process each batch as it is produced + +DataFrames are also iterable, yielding :class:`datafusion.RecordBatch` +objects lazily so you can loop over results directly without importing +PyArrow: + +.. code-block:: python + + for batch in df: + ... # each batch is a ``datafusion.RecordBatch`` + +Each batch exposes ``to_pyarrow()``, allowing conversion to a PyArrow +table. ``pa.table(df)`` collects the entire DataFrame eagerly into a +PyArrow table:: + +.. code-block:: python + + import pyarrow as pa + table = pa.table(df) + +Asynchronous iteration is supported as well, allowing integration with +``asyncio`` event loops:: + +.. code-block:: python + + async for batch in df: + ... # process each batch as it is produced + +To work with the stream directly, use ``execute_stream()``, which returns a +:class:`~datafusion.RecordBatchStream`:: + +.. code-block:: python + + stream = df.execute_stream() + for batch in stream: + ... + +Execute as Stream +^^^^^^^^^^^^^^^^^ + +For finer control over streaming execution, use +:py:meth:`~datafusion.DataFrame.execute_stream` to obtain a +:py:class:`datafusion.RecordBatchStream`: + +.. code-block:: python + + stream = df.execute_stream() + for batch in stream: + ... # process each batch as it is produced + +.. tip:: + + To get a PyArrow reader instead, call + ``pa.RecordBatchReader.from_stream(df)``. + +When partition boundaries are important, +:py:meth:`~datafusion.DataFrame.execute_stream_partitioned` +returns an iterable of :py:class:`datafusion.RecordBatchStream` objects, one per +partition: + +.. code-block:: python + + for stream in df.execute_stream_partitioned(): + for batch in stream: + ... # each stream yields RecordBatches + +To process partitions concurrently, first collect the streams into a list +and then poll each one in a separate ``asyncio`` task: + +.. code-block:: python + + import asyncio + + async def consume(stream): + async for batch in stream: + ... + + streams = list(df.execute_stream_partitioned()) + await asyncio.gather(*(consume(s) for s in streams)) + +See :doc:`../io/arrow` for additional details on the Arrow interface. + HTML Rendering -------------- diff --git a/docs/source/user-guide/io/arrow.rst b/docs/source/user-guide/io/arrow.rst index d571aa99c..372c0d5af 100644 --- a/docs/source/user-guide/io/arrow.rst +++ b/docs/source/user-guide/io/arrow.rst @@ -60,14 +60,22 @@ Exporting from DataFusion DataFusion DataFrames implement ``__arrow_c_stream__`` PyCapsule interface, so any Python library that accepts these can import a DataFusion DataFrame directly. -.. warning:: - It is important to note that this will cause the DataFrame execution to happen, which may be - a time consuming task. That is, you will cause a - :py:func:`datafusion.dataframe.DataFrame.collect` operation call to occur. +.. note:: + Invoking ``__arrow_c_stream__`` still triggers execution of the underlying + query, but batches are yielded incrementally rather than materialized all at + once in memory. 
Consumers can process the stream as it arrives, avoiding the + memory overhead of a full + :py:func:`datafusion.dataframe.DataFrame.collect`. + + For an example of this streamed execution and its memory safety, see the + ``test_arrow_c_stream_large_dataset`` unit test in + :mod:`python.tests.test_io`. .. ipython:: python + from datafusion import col, lit + df = df.select((col("a") * lit(1.5)).alias("c"), lit("df").alias("d")) pa.table(df) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 181c29db4..ece8290c2 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -25,7 +25,9 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Iterable, + Iterator, Literal, Optional, Union, @@ -42,7 +44,7 @@ from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import Expr, SortExpr, sort_or_default from datafusion.plan import ExecutionPlan, LogicalPlan -from datafusion.record_batch import RecordBatchStream +from datafusion.record_batch import RecordBatch, RecordBatchStream if TYPE_CHECKING: import pathlib @@ -296,6 +298,9 @@ def __init__( class DataFrame: """Two dimensional table representation of data. + DataFrame objects are iterable; iterating over a DataFrame yields + :class:`datafusion.RecordBatch` instances lazily. + See :ref:`user_guide_concepts` in the online documentation for more information. """ @@ -312,7 +317,7 @@ def into_view(self) -> pa.Table: return self.df.into_view() def __getitem__(self, key: str | list[str]) -> DataFrame: - """Return a new :py:class`DataFrame` with the specified column or columns. + """Return a new :py:class:`DataFrame` with the specified column or columns. Args: key: Column name or list of column names to select. @@ -1105,21 +1110,54 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: - """Export an Arrow PyCapsule Stream. + """Export the DataFrame as an Arrow C Stream. + + The DataFrame is executed using DataFusion's streaming APIs and exposed via + Arrow's C Stream interface. Record batches are produced incrementally, so the + full result set is never materialized in memory. - This will execute and collect the DataFrame. We will attempt to respect the - requested schema, but only trivial transformations will be applied such as only - returning the fields listed in the requested schema if their data types match - those in the DataFrame. + When ``requested_schema`` is provided, DataFusion applies only simple + projections such as selecting a subset of existing columns or reordering + them. Column renaming, computed expressions, or type coercion are not + supported through this interface. Args: - requested_schema: Attempt to provide the DataFrame using this schema. + requested_schema: Either a :py:class:`pyarrow.Schema` or an Arrow C + Schema capsule (``PyCapsule``) produced by + ``schema._export_to_c_capsule()``. The DataFrame will attempt to + align its output with the fields and order specified by this schema. Returns: - Arrow PyCapsule object. + Arrow ``PyCapsule`` object representing an ``ArrowArrayStream``. 
+ + Examples: + >>> schema = df.schema() + >>> stream = df.__arrow_c_stream__(schema) + >>> capsule = schema._export_to_c_capsule() + >>> stream = df.__arrow_c_stream__(capsule) + + Notes: + The Arrow C Data Interface PyCapsule details are documented by Apache + Arrow and can be found at: + https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html """ + # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages + # ``execute_stream_partitioned`` under the hood to stream batches while + # preserving the original partition order. return self.df.__arrow_c_stream__(requested_schema) + def __iter__(self) -> Iterator[RecordBatch]: + """Return an iterator over this DataFrame's record batches.""" + return iter(self.execute_stream()) + + def __aiter__(self) -> AsyncIterator[RecordBatch]: + """Return an async iterator over this DataFrame's record batches. + + We're using __aiter__ because we support Python < 3.10 where aiter() is not + available. + """ + return self.execute_stream().__aiter__() + def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame: """Apply a function to the current DataFrame which returns another DataFrame. diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 556eaa786..c24cde0ac 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -46,6 +46,26 @@ def to_pyarrow(self) -> pa.RecordBatch: """Convert to :py:class:`pa.RecordBatch`.""" return self.record_batch.to_pyarrow() + def __arrow_c_array__( + self, requested_schema: object | None = None + ) -> tuple[object, object]: + """Export the record batch via the Arrow C Data Interface. + + This allows zero-copy interchange with libraries that support the + `Arrow PyCapsule interface `_. + + Args: + requested_schema: Attempt to provide the record batch using this + schema. Only straightforward projections such as column + selection or reordering are applied. + + Returns: + Two Arrow PyCapsule objects representing the ``ArrowArray`` and + ``ArrowSchema``. + """ + return self.record_batch.__arrow_c_array__(requested_schema) + class RecordBatchStream: """This class represents a stream of record batches. 
@@ -63,19 +83,19 @@ def next(self) -> RecordBatch: return next(self) async def __anext__(self) -> RecordBatch: - """Async iterator function.""" + """Return the next :py:class:`RecordBatch` in the stream asynchronously.""" next_batch = await self.rbs.__anext__() return RecordBatch(next_batch) def __next__(self) -> RecordBatch: - """Iterator function.""" + """Return the next :py:class:`RecordBatch` in the stream.""" next_batch = next(self.rbs) return RecordBatch(next_batch) def __aiter__(self) -> typing_extensions.Self: - """Async iterator function.""" + """Return an asynchronous iterator over record batches.""" return self def __iter__(self) -> typing_extensions.Self: - """Iterator function.""" + """Return an iterator over record batches.""" return self diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 9548fbfe4..26ed7281d 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -17,7 +17,7 @@ import pyarrow as pa import pytest -from datafusion import SessionContext +from datafusion import DataFrame, SessionContext from pyarrow.csv import write_csv @@ -49,3 +49,12 @@ def database(ctx, tmp_path): delimiter=",", schema_infer_max_records=10, ) + + +@pytest.fixture +def fail_collect(monkeypatch): + def _fail_collect(self, *args, **kwargs): # pragma: no cover - failure path + msg = "collect should not be called" + raise AssertionError(msg) + + monkeypatch.setattr(DataFrame, "collect", _fail_collect) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 343d32a92..262795530 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,6 +29,7 @@ DataFrame, ParquetColumnOptions, ParquetWriterOptions, + RecordBatch, SessionContext, WindowFrame, column, @@ -46,6 +47,8 @@ from datafusion.expr import Window from pyarrow.csv import write_csv +pa_cffi = pytest.importorskip("pyarrow.cffi") + MB = 1024 * 1024 @@ -377,6 +380,41 @@ def test_cast(df): assert df.schema() == expected +def test_iter_batches(df): + batches = [] + for batch in df: + batches.append(batch) # noqa: PERF402 + + # Delete DataFrame to ensure RecordBatches remain valid + del df + + assert len(batches) == 1 + + batch = batches[0] + assert isinstance(batch, RecordBatch) + pa_batch = batch.to_pyarrow() + assert pa_batch.column(0).to_pylist() == [1, 2, 3] + assert pa_batch.column(1).to_pylist() == [4, 5, 6] + assert pa_batch.column(2).to_pylist() == [8, 5, 8] + + +def test_iter_returns_datafusion_recordbatch(df): + for batch in df: + assert isinstance(batch, RecordBatch) + + +def test_execute_stream_basic(df): + stream = df.execute_stream() + batches = list(stream) + + assert len(batches) == 1 + assert isinstance(batches[0], RecordBatch) + pa_batch = batches[0].to_pyarrow() + assert pa_batch.column(0).to_pylist() == [1, 2, 3] + assert pa_batch.column(1).to_pylist() == [4, 5, 6] + assert pa_batch.column(2).to_pylist() == [8, 5, 8] + + def test_with_column_renamed(df): df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum") @@ -1312,7 +1350,7 @@ def test_execution_plan(aggregate_df): @pytest.mark.asyncio async def test_async_iteration_of_df(aggregate_df): rows_returned = 0 - async for batch in aggregate_df.execute_stream(): + async for batch in aggregate_df: assert batch is not None rows_returned += len(batch.to_pyarrow()[0]) @@ -1582,6 +1620,121 @@ def test_empty_to_arrow_table(df): assert set(pyarrow_table.column_names) == {"a", "b", "c"} +def test_iter_batches_dataframe(fail_collect): + ctx = SessionContext() + + batch1 = 
pa.record_batch([pa.array([1])], names=["a"]) + batch2 = pa.record_batch([pa.array([2])], names=["a"]) + df = ctx.create_dataframe([[batch1], [batch2]]) + + expected = [batch1, batch2] + results = [b.to_pyarrow() for b in df] + + assert len(results) == len(expected) + for exp in expected: + assert any(got.equals(exp) for got in results) + + +def test_arrow_c_stream_to_table_and_reader(fail_collect): + ctx = SessionContext() + + # Create a DataFrame with two separate record batches + batch1 = pa.record_batch([pa.array([1])], names=["a"]) + batch2 = pa.record_batch([pa.array([2])], names=["a"]) + df = ctx.create_dataframe([[batch1], [batch2]]) + + table = pa.Table.from_batches(batch.to_pyarrow() for batch in df) + batches = table.to_batches() + + assert len(batches) == 2 + expected = [batch1, batch2] + for exp in expected: + assert any(got.equals(exp) for got in batches) + assert table.schema == df.schema() + assert table.column("a").num_chunks == 2 + + reader = pa.RecordBatchReader.from_stream(df) + assert isinstance(reader, pa.RecordBatchReader) + reader_table = pa.Table.from_batches(reader) + expected = pa.Table.from_batches([batch1, batch2]) + assert reader_table.equals(expected) + + +def test_arrow_c_stream_order(): + ctx = SessionContext() + + batch1 = pa.record_batch([pa.array([1])], names=["a"]) + batch2 = pa.record_batch([pa.array([2])], names=["a"]) + + df = ctx.create_dataframe([[batch1, batch2]]) + + table = pa.Table.from_batches(batch.to_pyarrow() for batch in df) + expected = pa.Table.from_batches([batch1, batch2]) + + assert table.equals(expected) + col = table.column("a") + assert col.chunk(0)[0].as_py() == 1 + assert col.chunk(1)[0].as_py() == 2 + + +def test_arrow_c_stream_schema_selection(fail_collect): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [ + pa.array([1, 2]), + pa.array([3, 4]), + pa.array([5, 6]), + ], + names=["a", "b", "c"], + ) + df = ctx.create_dataframe([[batch]]) + + requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())]) + + c_schema = pa_cffi.ffi.new("struct ArrowSchema*") + address = int(pa_cffi.ffi.cast("uintptr_t", c_schema)) + requested_schema._export_to_c(address) + capsule_new = ctypes.pythonapi.PyCapsule_New + capsule_new.restype = ctypes.py_object + capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] + + reader = pa.RecordBatchReader.from_stream(df, schema=requested_schema) + + assert reader.schema == requested_schema + + batches = list(reader) + + assert len(batches) == 1 + expected_batch = pa.record_batch( + [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"] + ) + assert batches[0].equals(expected_batch) + + +def test_arrow_c_stream_schema_mismatch(fail_collect): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([3, 4])], names=["a", "b"] + ) + df = ctx.create_dataframe([[batch]]) + + bad_schema = pa.schema([("a", pa.string())]) + + c_schema = pa_cffi.ffi.new("struct ArrowSchema*") + address = int(pa_cffi.ffi.cast("uintptr_t", c_schema)) + bad_schema._export_to_c(address) + + capsule_new = ctypes.pythonapi.PyCapsule_New + capsule_new.restype = ctypes.py_object + capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] + bad_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None) + + with pytest.raises(Exception, match="Fail to merge schema"): + df.__arrow_c_stream__(bad_capsule) + + def test_to_pylist(df): # Convert datafusion dataframe to Python list pylist = df.to_pylist() @@ -2666,6 +2819,110 @@ def 
trigger_interrupt(): interrupt_thread.join(timeout=1.0) +def test_arrow_c_stream_interrupted(): + """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals. + + Similar to ``test_collect_interrupted`` this test issues a long running + query, but consumes the results via ``__arrow_c_stream__``. It then raises + ``KeyboardInterrupt`` in the main thread and verifies that the stream + iteration stops promptly with the appropriate exception. + """ + + ctx = SessionContext() + + batches = [] + for i in range(10): + batch = pa.RecordBatch.from_arrays( + [ + pa.array(list(range(i * 1000, (i + 1) * 1000))), + pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), + ], + names=["a", "b"], + ) + batches.append(batch) + + ctx.register_record_batches("t1", [batches]) + ctx.register_record_batches("t2", [batches]) + + df = ctx.sql( + """ + WITH t1_expanded AS ( + SELECT + a, + b, + CAST(a AS DOUBLE) / 1.5 AS c, + CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d + FROM t1 + CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) + ), + t2_expanded AS ( + SELECT + a, + b, + CAST(a AS DOUBLE) * 2.5 AS e, + CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f + FROM t2 + CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) + ) + SELECT + t1.a, t1.b, t1.c, t1.d, + t2.a AS a2, t2.b AS b2, t2.e, t2.f + FROM t1_expanded t1 + JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 + WHERE t1.a > 100 AND t2.a > 100 + """ + ) + + reader = pa.RecordBatchReader.from_stream(df) + + interrupted = False + interrupt_error = None + query_started = threading.Event() + max_wait_time = 5.0 + + def trigger_interrupt(): + start_time = time.time() + while not query_started.is_set(): + time.sleep(0.1) + if time.time() - start_time > max_wait_time: + msg = f"Query did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + + thread_id = threading.main_thread().ident + if thread_id is None: + msg = "Cannot get main thread ID" + raise RuntimeError(msg) + + exception = ctypes.py_object(KeyboardInterrupt) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread_id), exception + ) + if res != 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread_id), ctypes.py_object(0) + ) + msg = "Failed to raise KeyboardInterrupt in main thread" + raise RuntimeError(msg) + + interrupt_thread = threading.Thread(target=trigger_interrupt) + interrupt_thread.daemon = True + interrupt_thread.start() + + try: + query_started.set() + # consume the reader which should block and be interrupted + reader.read_all() + except KeyboardInterrupt: + interrupted = True + except Exception as e: # pragma: no cover - unexpected errors + interrupt_error = e + + if not interrupted: + pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}") + + interrupt_thread.join(timeout=1.0) + + def test_show_select_where_no_rows(capsys) -> None: ctx = SessionContext() df = ctx.sql("SELECT 1 WHERE 1=0") diff --git a/python/tests/test_io.py b/python/tests/test_io.py index 7ca509689..9f56f74d7 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -14,12 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ from pathlib import Path import pyarrow as pa +import pytest from datafusion import column from datafusion.io import read_avro, read_csv, read_json, read_parquet +from .utils import range_table + def test_read_json_global_ctx(ctx): path = Path(__file__).parent.resolve() @@ -92,3 +96,43 @@ def test_read_avro(): path = Path.cwd() / "testing/data/avro/alltypes_plain.avro" avro_df = read_avro(path=path) assert avro_df is not None + + +def test_arrow_c_stream_large_dataset(ctx): + """DataFrame streaming yields batches incrementally using Arrow APIs. + + This test constructs a DataFrame that would be far larger than available + memory if materialized. Use the public API + ``pa.RecordBatchReader.from_stream(df)`` (which is same as + ``pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())``) + to read record batches incrementally without collecting the full dataset, + so reading a handful of batches should not exhaust process memory. + """ + # Create a very large DataFrame using range; this would be terabytes if collected + df = range_table(ctx, 0, 1 << 40) + + reader = pa.RecordBatchReader.from_stream(df) + + # Track RSS before consuming batches + # RSS is a practical measure of RAM usage visible to the OS. It excludes memory + # that has been swapped out and provides a simple cross-platform-ish indicator + # (psutil normalizes per-OS sources). + psutil = pytest.importorskip("psutil") + process = psutil.Process() + start_rss = process.memory_info().rss + + for _ in range(5): + batch = reader.read_next_batch() + assert batch is not None + assert len(batch) > 0 + current_rss = process.memory_info().rss + # Ensure memory usage hasn't grown substantially (>50MB) + assert current_rss - start_rss < 50 * 1024 * 1024 + + +def test_table_from_arrow_c_stream(ctx, fail_collect): + df = range_table(ctx, 0, 10) + + table = pa.table(df) + assert table.shape == (10, 1) + assert table.column_names == ["value"] diff --git a/python/tests/utils.py b/python/tests/utils.py new file mode 100644 index 000000000..00efb6555 --- /dev/null +++ b/python/tests/utils.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Testing-only helpers for datafusion-python. + +This module contains utilities used by the test-suite that should not be +exposed as part of the public API. Keep the implementation minimal and +documented so reviewers can easily see it's test-only. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datafusion import DataFrame + from datafusion.context import SessionContext + + +def range_table( + ctx: SessionContext, + start: int, + stop: int | None = None, + step: int = 1, + partitions: int | None = None, +) -> DataFrame: + """Create a DataFrame containing a sequence of numbers using SQL RANGE. + + This mirrors the previous ``SessionContext.range`` convenience method but + lives in a testing-only module so it doesn't expand the public surface. + + Args: + ctx: SessionContext instance to run the SQL against. + start: Starting value for the sequence or exclusive stop when ``stop`` + is ``None``. + stop: Exclusive upper bound of the sequence. + step: Increment between successive values. + partitions: Optional number of partitions for the generated data. + + Returns: + DataFrame produced by the range table function. + """ + if stop is None: + start, stop = 0, start + + parts = f", {int(partitions)}" if partitions is not None else "" + sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})" + return ctx.sql(sql) diff --git a/src/context.rs b/src/context.rs index 36133a33d..561fb37fa 100644 --- a/src/context.rs +++ b/src/context.rs @@ -34,7 +34,7 @@ use pyo3::prelude::*; use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; @@ -45,7 +45,7 @@ use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; -use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future}; +use crate::utils::{get_global_ctx, spawn_future, validate_pycapsule, wait_for_future}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -66,7 +66,6 @@ use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; use datafusion::execution::options::ReadOptions; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; @@ -74,7 +73,6 @@ use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvid use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use pyo3::IntoPyObjectExt; -use tokio::task::JoinHandle; /// Configuration options for a SessionContext #[pyclass(name = "SessionConfig", module = "datafusion", subclass)] @@ -1132,12 +1130,8 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult { let ctx: TaskContext = TaskContext::from(&self.ctx.state()); - // create a Tokio runtime to run the async code - let rt = &get_tokio_runtime().0; let plan = plan.plan.clone(); - let fut: JoinHandle> = - rt.spawn(async move { plan.execute(part, Arc::new(ctx)) }); - let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???; + let stream = spawn_future(py, async move { 
plan.execute(part, Arc::new(ctx)) })?; Ok(PyRecordBatchStream::new(stream)) } } diff --git a/src/dataframe.rs b/src/dataframe.rs index 46fba137c..db41a204f 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. +use cstr::cstr; use std::collections::HashMap; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::sync::Arc; -use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader}; +use arrow::array::{new_null_array, RecordBatch, RecordBatchReader}; use arrow::compute::can_cast_types; use arrow::error::ArrowError; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::FFI_ArrowArrayStream; use arrow::pyarrow::FromPyArrow; -use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::datatypes::{Schema, SchemaRef}; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; use datafusion::common::UnnestOptions; @@ -42,22 +43,26 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; -use tokio::task::JoinHandle; +use pyo3::PyErr; use crate::catalog::PyTable; -use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError}; +use crate::errors::{py_datafusion_err, PyDataFusionError}; use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; -use crate::record_batch::PyRecordBatchStream; +use crate::record_batch::{poll_next_batch, PyRecordBatchStream}; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, + wait_for_future, }; use crate::{ errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}, }; +/// File-level static CStr for the Arrow array stream capsule name. +static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream"); + // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 // - we have not decided on the table_provider approach yet // this is an interim implementation @@ -354,6 +359,63 @@ impl PyDataFrame { } } +/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used +/// for the `__arrow_c_stream__` implementation. +/// +/// It drains each partition's stream sequentially, yielding record batches in +/// their original partition order. When a `projection` is set, each batch is +/// converted via `record_batch_into_schema` to apply schema changes per batch. 
+struct PartitionedDataFrameStreamReader { + streams: Vec, + schema: SchemaRef, + projection: Option, + current: usize, +} + +impl Iterator for PartitionedDataFrameStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + while self.current < self.streams.len() { + let stream = &mut self.streams[self.current]; + let fut = poll_next_batch(stream); + let result = Python::with_gil(|py| wait_for_future(py, fut)); + + match result { + Ok(Ok(Some(batch))) => { + let batch = if let Some(ref schema) = self.projection { + match record_batch_into_schema(batch, schema.as_ref()) { + Ok(b) => b, + Err(e) => return Some(Err(e)), + } + } else { + batch + }; + return Some(Ok(batch)); + } + Ok(Ok(None)) => { + self.current += 1; + continue; + } + Ok(Err(e)) => { + return Some(Err(ArrowError::ExternalError(Box::new(e)))); + } + Err(e) => { + return Some(Err(ArrowError::ExternalError(Box::new(e)))); + } + } + } + + None + } +} + +impl RecordBatchReader for PartitionedDataFrameStreamReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[pymethods] impl PyDataFrame { /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]` @@ -879,8 +941,11 @@ impl PyDataFrame { py: Python<'py>, requested_schema: Option>, ) -> PyDataFusionResult> { - let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??; + let df = self.df.as_ref().clone(); + let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?; + let mut schema: Schema = self.df.schema().to_owned().into(); + let mut projection: Option = None; if let Some(schema_capsule) = requested_schema { validate_pycapsule(&schema_capsule, "arrow_schema")?; @@ -889,44 +954,38 @@ impl PyDataFrame { let desired_schema = Schema::try_from(schema_ptr)?; schema = project_schema(schema, desired_schema)?; - - batches = batches - .into_iter() - .map(|record_batch| record_batch_into_schema(record_batch, &schema)) - .collect::, ArrowError>>()?; + projection = Some(Arc::new(schema.clone())); } - let batches_wrapped = batches.into_iter().map(Ok); + let schema_ref = Arc::new(schema.clone()); - let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema)); + let reader = PartitionedDataFrameStreamReader { + streams, + schema: schema_ref, + projection, + current: 0, + }; let reader: Box = Box::new(reader); - let ffi_stream = FFI_ArrowArrayStream::new(reader); - let stream_capsule_name = CString::new("arrow_array_stream").unwrap(); - PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from) + // Create the Arrow stream and wrap it in a PyCapsule. The default + // destructor provided by PyO3 will drop the stream unless ownership is + // transferred to PyArrow during import. 
+ let stream = FFI_ArrowArrayStream::new(reader); + let name = CString::new(ARROW_ARRAY_STREAM_NAME.to_bytes()).unwrap(); + let capsule = PyCapsule::new(py, stream, Some(name))?; + Ok(capsule) } fn execute_stream(&self, py: Python) -> PyDataFusionResult { - // create a Tokio runtime to run the async code - let rt = &get_tokio_runtime().0; let df = self.df.as_ref().clone(); - let fut: JoinHandle> = - rt.spawn(async move { df.execute_stream().await }); - let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???; + let stream = spawn_future(py, async move { df.execute_stream().await })?; Ok(PyRecordBatchStream::new(stream)) } fn execute_stream_partitioned(&self, py: Python) -> PyResult> { - // create a Tokio runtime to run the async code - let rt = &get_tokio_runtime().0; let df = self.df.as_ref().clone(); - let fut: JoinHandle>> = - rt.spawn(async move { df.execute_stream_partitioned().await }); - let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })? - .map_err(py_datafusion_err)? - .map_err(py_datafusion_err)?; - - Ok(stream.into_iter().map(PyRecordBatchStream::new).collect()) + let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?; + Ok(streams.into_iter().map(PyRecordBatchStream::new).collect()) } /// Convert to pandas dataframe with pyarrow @@ -1026,7 +1085,11 @@ fn project_schema(from_schema: Schema, to_schema: Schema) -> Result>` form. +pub(crate) async fn poll_next_batch( + stream: &mut SendableRecordBatchStream, +) -> datafusion::error::Result> { + stream.next().await.transpose() +} + async fn next_stream( stream: Arc>, sync: bool, ) -> PyResult { let mut stream = stream.lock().await; - match stream.next().await { - Some(Ok(batch)) => Ok(batch.into()), - Some(Err(e)) => Err(PyDataFusionError::from(e))?, - None => { + match poll_next_batch(&mut stream).await { + Ok(Some(batch)) => Ok(batch.into()), + Ok(None) => { // Depending on whether the iteration is sync or not, we raise either a // StopIteration or a StopAsyncIteration if sync { @@ -101,5 +107,6 @@ async fn next_stream( Err(PyStopAsyncIteration::new_err("stream exhausted")) } } + Err(e) => Err(PyDataFusionError::from(e))?, } } diff --git a/src/utils.rs b/src/utils.rs index 3b30de5de..483095d3c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,7 +17,7 @@ use crate::{ common::data_type::PyScalarValue, - errors::{PyDataFusionError, PyDataFusionResult}, + errors::{to_datafusion_err, PyDataFusionError, PyDataFusionResult}, TokioRuntime, }; use datafusion::{ @@ -26,7 +26,7 @@ use datafusion::{ use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; use std::{future::Future, sync::OnceLock, time::Duration}; -use tokio::{runtime::Runtime, time::sleep}; +use tokio::{runtime::Runtime, task::JoinHandle, time::sleep}; /// Utility to get the Tokio Runtime from Python #[inline] pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { @@ -84,6 +84,35 @@ where }) } +/// Spawn a [`Future`] on the Tokio runtime and wait for completion +/// while respecting Python signal handling. +pub(crate) fn spawn_future(py: Python, fut: F) -> PyDataFusionResult +where + F: Future> + Send + 'static, + T: Send + 'static, +{ + let rt = &get_tokio_runtime().0; + let handle: JoinHandle> = rt.spawn(fut); + // Wait for the join handle while respecting Python signal handling. 
+    // We handle errors in two steps so `?` maps the error types correctly:
+    // 1) convert any Python-related error from `wait_for_future` into `PyDataFusionError`
+    // 2) convert any DataFusion error (inner result) into `PyDataFusionError`
+    let inner_result = wait_for_future(py, async {
+        // handle.await yields `Result<Result<T, DataFusionError>, JoinError>`
+        // map JoinError into a DataFusion error so the async block returns
+        // `datafusion::common::Result<T>` (i.e. Result<T, DataFusionError>)
+        match handle.await {
+            Ok(inner) => inner,
+            Err(join_err) => Err(to_datafusion_err(join_err)),
+        }
+    })?; // converts PyErr -> PyDataFusionError
+
+    // `inner_result` is `datafusion::common::Result<T>`; use `?` to convert
+    // the inner DataFusion error into `PyDataFusionError` via `From` and
+    // return the inner `T` on success.
+    Ok(inner_result?)
+}
+
 pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
     Ok(match value {
         "immutable" => Volatility::Immutable,