Skip to content

Commit 68f50ef

Browse files
committed
feat: add DataFrame.to_pandas_batches() to download large DataFrame objects
1 parent 6e28da3 commit 68f50ef

File tree

5 files changed

+213
-8
lines changed

5 files changed

+213
-8
lines changed

bigframes/core/blocks.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ def to_pandas(
412412
)
413413
return df, query_job
414414

415+
def to_pandas_batches(self):
    """Stream the block's query results as a series of pandas DataFrames.

    Yields one DataFrame per Arrow message downloaded from the query
    results, with the block's dtypes applied and index columns restored.
    """
    # Map every column name (index columns first, then value columns)
    # to the dtype the resulting pandas objects should carry.
    column_dtypes = {
        **dict(zip(self.index_columns, self.index_dtypes)),
        **dict(zip(self.value_columns, self.dtypes)),
    }
    results_iterator, _ = self._expr.start_query()
    arrow_batches = results_iterator.to_arrow_iterable(
        bqstorage_client=self._expr._session.bqstoragereadclient
    )
    for record_batch in arrow_batches:
        frame = bigframes.session._io.pandas.arrow_to_pandas(
            record_batch, column_dtypes
        )
        self._copy_index_to_pandas(frame)
        yield frame
426+
427+
def _copy_index_to_pandas(self, df: pd.DataFrame):
428+
"""Set the index on pandas DataFrame to match this block.
429+
430+
Warning: This method modifies ``df`` inplace.
431+
"""
432+
if self.index_columns:
433+
df.set_index(list(self.index_columns), inplace=True)
434+
df.index.names = self.index.names # type: ignore
435+
415436
def _compute_and_count(
416437
self,
417438
value_keys: Optional[Iterable[str]] = None,
@@ -485,10 +506,7 @@ def _compute_and_count(
485506
else:
486507
total_rows = results_iterator.total_rows
487508
df = self._to_dataframe(results_iterator)
488-
489-
if self.index_columns:
490-
df.set_index(list(self.index_columns), inplace=True)
491-
df.index.names = self.index.names # type: ignore
509+
self._copy_index_to_pandas(df)
492510

493511
return df, total_rows, query_job
494512

bigframes/dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,10 @@ def to_pandas(
893893
self._set_internal_query_job(query_job)
894894
return df.set_axis(self._block.column_labels, axis=1, copy=False)
895895

896+
def to_pandas_batches(self) -> Iterable[pandas.DataFrame]:
    """Stream this DataFrame's results as an iterable of pandas DataFrame chunks."""
    # Delegate streaming to the underlying block, which downloads the
    # query results one message at a time.
    return self._block.to_pandas_batches()
899+
896900
def _compute_dry_run(self) -> bigquery.QueryJob:
    # Delegates to the underlying block's dry-run computation and returns
    # the resulting BigQuery QueryJob.
    return self._block._compute_dry_run()
898902

bigframes/session/_io/pandas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
import bigframes.constants
2222

2323

24-
def arrow_to_pandas(arrow_table: pyarrow.Table, dtypes: Dict):
24+
def arrow_to_pandas(arrow_table: pyarrow.Table | pyarrow.RecordBatch, dtypes: Dict):
2525
if len(dtypes) != arrow_table.num_columns:
2626
raise ValueError(
2727
f"Number of types {len(dtypes)} doesn't match number of columns "
2828
f"{arrow_table.num_columns}. {bigframes.constants.FEEDBACK_LINK}"
2929
)
3030

3131
serieses = {}
32-
for column_name, column in zip(arrow_table.column_names, arrow_table):
33-
dtype = dtypes[column_name]
32+
for field, column in zip(arrow_table.schema, arrow_table):
33+
dtype = dtypes[field.name]
3434

3535
if dtype == geopandas.array.GeometryDtype():
3636
series = geopandas.GeoSeries.from_wkt(
@@ -41,6 +41,6 @@ def arrow_to_pandas(arrow_table: pyarrow.Table, dtypes: Dict):
4141
else:
4242
series = pandas.Series(column, dtype=dtype)
4343

44-
serieses[column_name] = series
44+
serieses[field.name] = series
4545

4646
return pandas.DataFrame(serieses)

tests/system/small/test_dataframe_io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def test_to_pandas_array_struct_correct_result(session):
8383
)
8484

8585

86+
def test_to_pandas_batches_w_correct_dtypes(scalars_df_default_index):
87+
"""Verify to_pandas_batches() APIs returns the expected dtypes."""
88+
expected = scalars_df_default_index.dtypes
89+
for df in scalars_df_default_index.to_pandas_batches():
90+
actual = df.dtypes
91+
pd.testing.assert_series_equal(actual, expected)
92+
93+
8694
@pytest.mark.parametrize(
8795
("index"),
8896
[True, False],

tests/unit/session/test_io_pandas.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
from typing import Dict, Union

import geopandas  # type: ignore
import pandas
import pandas.testing
import pyarrow  # type: ignore
import pytest

import bigframes.session._io.pandas
25+
26+
27+
@pytest.mark.parametrize(
    ("arrow_table", "dtypes", "expected"),
    (
        pytest.param(
            pyarrow.Table.from_pydict({}),
            {},
            pandas.DataFrame(),
            id="empty-df",
        ),
        pytest.param(
            pyarrow.Table.from_pydict(
                {
                    "bool": [True, None, False],
                    "bytes": [b"123", None, b"abc"],
                    "date": pyarrow.array(
                        [datetime.date(2023, 8, 29), None, datetime.date(2024, 4, 9)],
                        type=pyarrow.date32(),
                    ),
                    "datetime": pyarrow.array(
                        [
                            datetime.datetime(2023, 8, 29),
                            None,
                            datetime.datetime(2024, 4, 9, 23, 59, 59),
                        ],
                        type=pyarrow.timestamp("us"),
                    ),
                    "string": ["123", None, "abc"],
                    "time": pyarrow.array(
                        [
                            datetime.time(0, 0, 0, 1),
                            None,
                            datetime.time(23, 59, 59, 999999),
                        ],
                        type=pyarrow.time64("us"),
                    ),
                    "timestamp": pyarrow.array(
                        [
                            datetime.datetime(2023, 8, 29),
                            None,
                            datetime.datetime(2024, 4, 9, 23, 59, 59),
                        ],
                        type=pyarrow.timestamp("us", datetime.timezone.utc),
                    ),
                }
            ),
            {
                "bool": "boolean",
                "bytes": "object",
                "date": pandas.ArrowDtype(pyarrow.date32()),
                "datetime": pandas.ArrowDtype(pyarrow.timestamp("us")),
                "string": "string[pyarrow]",
                "time": pandas.ArrowDtype(pyarrow.time64("us")),
                "timestamp": pandas.ArrowDtype(
                    pyarrow.timestamp("us", datetime.timezone.utc)
                ),
            },
            pandas.DataFrame(
                {
                    "bool": pandas.Series([True, None, False], dtype="boolean"),
                    "bytes": [b"123", None, b"abc"],
                    "date": pandas.Series(
                        [datetime.date(2023, 8, 29), None, datetime.date(2024, 4, 9)],
                        dtype=pandas.ArrowDtype(pyarrow.date32()),
                    ),
                    "datetime": pandas.Series(
                        [
                            datetime.datetime(2023, 8, 29),
                            None,
                            datetime.datetime(2024, 4, 9, 23, 59, 59),
                        ],
                        dtype=pandas.ArrowDtype(pyarrow.timestamp("us")),
                    ),
                    "string": pandas.Series(
                        ["123", None, "abc"], dtype="string[pyarrow]"
                    ),
                    "time": pandas.Series(
                        [
                            datetime.time(0, 0, 0, 1),
                            None,
                            datetime.time(23, 59, 59, 999999),
                        ],
                        dtype=pandas.ArrowDtype(pyarrow.time64("us")),
                    ),
                    "timestamp": pandas.Series(
                        [
                            datetime.datetime(2023, 8, 29),
                            None,
                            datetime.datetime(2024, 4, 9, 23, 59, 59),
                        ],
                        dtype=pandas.ArrowDtype(
                            pyarrow.timestamp("us", datetime.timezone.utc)
                        ),
                    ),
                }
            ),
            id="scalar-dtypes",
        ),
        pytest.param(
            pyarrow.Table.from_pydict(
                {
                    "geocol": [
                        "POINT(32 210)",
                        None,
                        "LINESTRING(1 1, 2 1, 3.1 2.88, 3 -3)",
                    ]
                }
            ),
            {"geocol": geopandas.array.GeometryDtype()},
            pandas.DataFrame(
                {
                    "geocol": geopandas.GeoSeries.from_wkt(
                        ["POINT(32 210)", None, "LINESTRING(1 1, 2 1, 3.1 2.88, 3 -3)"],
                        crs="EPSG:4326",
                    ),
                }
            ),
            id="geography-dtype",
        ),
    ),
)
def test_arrow_to_pandas(
    # NOTE: use typing.Union rather than PEP 604 "pyarrow.Table | pyarrow.RecordBatch".
    # Annotations are evaluated at function-definition time here (no
    # `from __future__ import annotations` in this module), and "|" between
    # classes requires Python 3.10+, so the original annotation raises
    # TypeError on import under Python 3.9.
    arrow_table: Union[pyarrow.Table, pyarrow.RecordBatch],
    dtypes: Dict,
    expected: pandas.DataFrame,
):
    """arrow_to_pandas() converts an Arrow table/batch to a DataFrame with the requested dtypes."""
    actual = bigframes.session._io.pandas.arrow_to_pandas(arrow_table, dtypes)
    pandas.testing.assert_frame_equal(actual, expected)
154+
155+
156+
@pytest.mark.parametrize(
    ("arrow_table", "dtypes"),
    (
        pytest.param(
            pyarrow.Table.from_pydict({"col1": [1], "col2": [2]}),
            {"col1": "Int64"},
            id="too-few-dtypes",
        ),
        pytest.param(
            pyarrow.RecordBatch.from_pydict({"col1": [1]}),
            {"col1": "Int64", "col2": "string[pyarrow]"},
            id="too-many-dtypes",
        ),
    ),
)
def test_arrow_to_pandas_wrong_size_dtypes(
    # NOTE: use typing.Union rather than PEP 604 "pyarrow.Table | pyarrow.RecordBatch".
    # This annotation is evaluated at function-definition time (no
    # `from __future__ import annotations` in this module), and "|" between
    # classes requires Python 3.10+, so the original raises TypeError on
    # import under Python 3.9.
    arrow_table: Union[pyarrow.Table, pyarrow.RecordBatch],
    dtypes: Dict,
):
    """arrow_to_pandas() raises ValueError when the dtypes mapping's size doesn't match the column count."""
    with pytest.raises(ValueError, match=f"Number of types {len(dtypes)}"):
        bigframes.session._io.pandas.arrow_to_pandas(arrow_table, dtypes)

0 commit comments

Comments
 (0)