
Commit f5fdf59

timsaucer and Copilot authored
feat: expose DataFrame.write_table (apache#1264)
* Initial commit for dataframe write_table
* Add dataframe writer options and docstring
* Add csv write unit test
* Add docstrings
* More testing around writer options
* Minor docstring change (Co-authored-by: Copilot <[email protected]>)
* Format docstring so it renders better
* Whitespace
* Resolve error on insert operation and add unit test coverage
* Mark classes as frozen

Co-authored-by: Copilot <[email protected]>
1 parent 6f3b1ca commit f5fdf59
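
As a quick orientation before the diff, here is a minimal usage sketch of the API this commit adds. It follows the signatures introduced below and mirrors the new unit tests; the table name "t" and the sample data are illustrative only.

    import pyarrow as pa
    from datafusion import DataFrameWriteOptions, InsertOp, SessionContext, column

    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
    ctx.register_record_batches("t", [[batch]])

    # Append the DataFrame's rows into the registered table "t",
    # sorting by column "a" before writing.
    ctx.table("t").write_table(
        "t",
        write_options=DataFrameWriteOptions(
            insert_operation=InsertOp.APPEND,
            sort_by=column("a").sort(ascending=True),
        ),
    )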

File tree

5 files changed (+339, -35 lines)


python/datafusion/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -44,7 +44,13 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions
+from .dataframe import (
+    DataFrame,
+    DataFrameWriteOptions,
+    InsertOp,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .dataframe_formatter import configure_formatter
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
@@ -71,9 +77,11 @@
     "Config",
     "DFSchema",
     "DataFrame",
+    "DataFrameWriteOptions",
     "Database",
     "ExecutionPlan",
     "Expr",
+    "InsertOp",
     "LogicalPlan",
     "ParquetColumnOptions",
     "ParquetWriterOptions",

python/datafusion/dataframe.py

Lines changed: 120 additions & 15 deletions
@@ -39,10 +39,13 @@
 from typing_extensions import deprecated  # Python 3.12

 from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion._internal import DataFrameWriteOptions as DataFrameWriteOptionsInternal
+from datafusion._internal import InsertOp as InsertOpInternal
 from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
 from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
 from datafusion.expr import (
     Expr,
+    SortExpr,
     SortKey,
     ensure_expr,
     ensure_expr_list,
@@ -939,21 +942,31 @@ def except_all(self, other: DataFrame) -> DataFrame:
         """
         return DataFrame(self.df.except_all(other.df))

-    def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None:
+    def write_csv(
+        self,
+        path: str | pathlib.Path,
+        with_header: bool = False,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a CSV file.

         Args:
             path: Path of the CSV file to write.
             with_header: If true, output the CSV header row.
+            write_options: Options that impact how the DataFrame is written.
         """
-        self.df.write_csv(str(path), with_header)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_csv(str(path), with_header, raw_write_options)

     @overload
     def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: str,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     @overload
@@ -962,6 +975,7 @@ def write_parquet(
         path: str | pathlib.Path,
         compression: Compression = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     @overload
@@ -970,31 +984,38 @@ def write_parquet(
         path: str | pathlib.Path,
         compression: ParquetWriterOptions,
         compression_level: None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.

+        Available compression types are:
+
+        - "uncompressed": No compression.
+        - "snappy": Snappy compression.
+        - "gzip": Gzip compression.
+        - "brotli": Brotli compression.
+        - "lz4": LZ4 compression.
+        - "lz4_raw": LZ4_RAW compression.
+        - "zstd": Zstandard compression.
+
+        LZO compression is not yet implemented in arrow-rs and is therefore
+        excluded.
+
         Args:
             path: Path of the Parquet file to write.
             compression: Compression type to use. Default is "ZSTD".
-                Available compression types are:
-                - "uncompressed": No compression.
-                - "snappy": Snappy compression.
-                - "gzip": Gzip compression.
-                - "brotli": Brotli compression.
-                - "lz4": LZ4 compression.
-                - "lz4_raw": LZ4_RAW compression.
-                - "zstd": Zstandard compression.
-                Note: LZO is not yet implemented in arrow-rs and is therefore excluded.
             compression_level: Compression level to use. For ZSTD, the
                 recommended range is 1 to 22, with the default being 4. Higher levels
                 provide better compression but slower speed.
+            write_options: Options that impact how the DataFrame is written.
         """
         if isinstance(compression, ParquetWriterOptions):
             if compression_level is not None:
@@ -1012,10 +1033,21 @@ def write_parquet(
         ):
             compression_level = compression.get_default_level()

-        self.df.write_parquet(str(path), compression.value, compression_level)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_parquet(
+            str(path),
+            compression.value,
+            compression_level,
+            raw_write_options,
+        )

     def write_parquet_with_options(
-        self, path: str | pathlib.Path, options: ParquetWriterOptions
+        self,
+        path: str | pathlib.Path,
+        options: ParquetWriterOptions,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -1024,6 +1056,7 @@ def write_parquet_with_options(
         Args:
             path: Path of the Parquet file to write.
             options: Sets the writer parquet options (see `ParquetWriterOptions`).
+            write_options: Options that impact how the DataFrame is written.
         """
         options_internal = ParquetWriterOptionsInternal(
             options.data_pagesize_limit,
@@ -1060,19 +1093,45 @@ def write_parquet_with_options(
             bloom_filter_ndv=opts.bloom_filter_ndv,
         )

+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
         self.df.write_parquet_with_options(
             str(path),
             options_internal,
             column_specific_options_internal,
+            raw_write_options,
         )

-    def write_json(self, path: str | pathlib.Path) -> None:
+    def write_json(
+        self,
+        path: str | pathlib.Path,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.

         Args:
             path: Path of the JSON file to write.
+            write_options: Options that impact how the DataFrame is written.
+        """
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_json(str(path), write_options=raw_write_options)
+
+    def write_table(
+        self, table_name: str, write_options: DataFrameWriteOptions | None = None
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to a table.
+
+        The table must be registered with the session to perform this operation.
+        Not all table providers support writing operations. See the individual
+        implementations for details.
         """
-        self.df.write_json(str(path))
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_table(table_name, raw_write_options)

     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1220,3 +1279,49 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame:
             - For columns not in subset, the original column is kept unchanged
         """
         return DataFrame(self.df.fill_null(value, subset))
+
+
+class InsertOp(Enum):
+    """Insert operation mode.
+
+    These modes are used by the table writing feature to define how record
+    batches should be written to a table.
+    """
+
+    APPEND = InsertOpInternal.APPEND
+    """Appends new rows to the existing table without modifying any existing rows."""
+
+    REPLACE = InsertOpInternal.REPLACE
+    """Replace existing rows that collide with the inserted rows.
+
+    Replacement is typically based on a unique key or primary key.
+    """
+
+    OVERWRITE = InsertOpInternal.OVERWRITE
+    """Overwrites all existing rows in the table with the new rows."""
+
+
+class DataFrameWriteOptions:
+    """Writer options for DataFrame.
+
+    There is no guarantee the table provider supports all writer options.
+    See the individual implementation and documentation for details.
+    """
+
+    def __init__(
+        self,
+        insert_operation: InsertOp | None = None,
+        single_file_output: bool = False,
+        partition_by: str | Sequence[str] | None = None,
+        sort_by: Expr | SortExpr | Sequence[Expr] | Sequence[SortExpr] | None = None,
+    ) -> None:
+        """Instantiate writer options for DataFrame."""
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
+
+        sort_by_raw = sort_list_to_raw_sort_list(sort_by)
+        insert_op = insert_operation.value if insert_operation is not None else None
+
+        self._raw_write_options = DataFrameWriteOptionsInternal(
+            insert_op, single_file_output, partition_by, sort_by_raw
+        )
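
Putting the pieces above together, a sketch of passing the same DataFrameWriteOptions to the updated file writers; the output paths and sample data are hypothetical.

    from datafusion import DataFrameWriteOptions, SessionContext, column

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [3, 1, 2], "b": ["x", "y", "z"]})

    # Rows are sorted by "a" ascending before each format below is written.
    sorted_opts = DataFrameWriteOptions(sort_by=column("a").sort(ascending=True))

    df.write_csv("/tmp/out_csv", with_header=True, write_options=sorted_opts)
    df.write_json("/tmp/out_json", write_options=sorted_opts)
    df.write_parquet("/tmp/out_parquet", compression="zstd", write_options=sorted_opts)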

python/tests/test_dataframe.py

Lines changed: 86 additions & 3 deletions
@@ -16,6 +16,7 @@
 # under the License.
 import ctypes
 import datetime
+import itertools
 import os
 import re
 import threading
@@ -27,6 +28,7 @@
 import pytest
 from datafusion import (
     DataFrame,
+    InsertOp,
     ParquetColumnOptions,
     ParquetWriterOptions,
     SessionContext,
@@ -40,6 +42,7 @@
 from datafusion import (
     functions as f,
 )
+from datafusion.dataframe import DataFrameWriteOptions
 from datafusion.dataframe_formatter import (
     DataFrameHtmlFormatter,
     configure_formatter,
@@ -58,9 +61,7 @@ def ctx():


 @pytest.fixture
-def df():
-    ctx = SessionContext()
-
+def df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
@@ -1830,6 +1831,69 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected


+def generate_test_write_params() -> list[tuple]:
+    # Overwrite and Replace are not implemented for many table writers
+    insert_ops = [InsertOp.APPEND, None]
+    sort_by_cases = [
+        (None, [1, 2, 3], "unsorted"),
+        (column("c"), [2, 1, 3], "single_column_expr"),
+        (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"),
+        ([column("c"), column("b")], [2, 1, 3], "list_col_expr"),
+        (
+            [column("c").sort(ascending=False), column("b").sort(ascending=False)],
+            [3, 1, 2],
+            "list_sort_expr",
+        ),
+    ]
+
+    formats = ["csv", "json", "parquet", "table"]
+
+    return [
+        pytest.param(
+            output_format,
+            insert_op,
+            sort_by,
+            expected_a,
+            id=f"{output_format}_{test_id}",
+        )
+        for output_format, insert_op, (
+            sort_by,
+            expected_a,
+            test_id,
+        ) in itertools.product(formats, insert_ops, sort_by_cases)
+    ]
+
+
+@pytest.mark.parametrize(
+    ("output_format", "insert_op", "sort_by", "expected_a"),
+    generate_test_write_params(),
+)
+def test_write_files_with_options(
+    ctx, df, tmp_path, output_format, insert_op, sort_by, expected_a
+) -> None:
+    write_options = DataFrameWriteOptions(insert_operation=insert_op, sort_by=sort_by)
+
+    if output_format == "csv":
+        df.write_csv(tmp_path, with_header=True, write_options=write_options)
+        ctx.register_csv("test_table", tmp_path)
+    elif output_format == "json":
+        df.write_json(tmp_path, write_options=write_options)
+        ctx.register_json("test_table", tmp_path)
+    elif output_format == "parquet":
+        df.write_parquet(tmp_path, write_options=write_options)
+        ctx.register_parquet("test_table", tmp_path)
+    elif output_format == "table":
+        batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema())
+        ctx.register_record_batches("test_table", [[batch]])
+        ctx.table("test_table").show()
+        df.write_table("test_table", write_options=write_options)
+
+    result = ctx.table("test_table").to_pydict()["a"]
+    ctx.table("test_table").show()
+
+    assert result == expected_a
+
+
 @pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_json(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
@@ -2322,6 +2386,25 @@ def test_write_parquet_options_error(df, tmp_path):
     df.write_parquet(str(tmp_path), options, compression_level=1)


+def test_write_table(ctx, df):
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3])],
+        names=["a"],
+    )
+
+    ctx.register_record_batches("t", [[batch]])
+
+    df = ctx.table("t").with_column("a", column("a") * literal(-1))
+
+    ctx.table("t").show()
+
+    df.write_table("t")
+    result = ctx.table("t").sort(column("a")).collect()[0][0].to_pylist()
+    expected = [-3, -2, -1, 1, 2, 3]
+
+    assert result == expected
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
