Skip to content

Commit 60325b8

Browse files
rhshadrachWillAyd
andauthored
ENH: Enable pytables to round-trip with StringDtype (#60663)
Co-authored-by: William Ayd <[email protected]>
1 parent 4c3b968 commit 60325b8

File tree

3 files changed

+87
-20
lines changed

3 files changed

+87
-20
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Other enhancements
3535
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
3636
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
3737
updated to work correctly with NumPy >= 2 (:issue:`57739`)
38+
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
3839
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
3940
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4041

pandas/io/pytables.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,16 @@
8686
PeriodArray,
8787
)
8888
from pandas.core.arrays.datetimes import tz_to_dtype
89+
from pandas.core.arrays.string_ import BaseStringArray
8990
import pandas.core.common as com
9091
from pandas.core.computation.pytables import (
9192
PyTablesExpr,
9293
maybe_expression,
9394
)
94-
from pandas.core.construction import extract_array
95+
from pandas.core.construction import (
96+
array as pd_array,
97+
extract_array,
98+
)
9599
from pandas.core.indexes.api import ensure_index
96100

97101
from pandas.io.common import stringify_path
@@ -3023,6 +3027,9 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None
30233027

30243028
if isinstance(node, tables.VLArray):
30253029
ret = node[0][start:stop]
3030+
dtype = getattr(attrs, "value_type", None)
3031+
if dtype is not None:
3032+
ret = pd_array(ret, dtype=dtype)
30263033
else:
30273034
dtype = getattr(attrs, "value_type", None)
30283035
shape = getattr(attrs, "shape", None)
@@ -3262,6 +3269,11 @@ def write_array(
32623269
elif lib.is_np_dtype(value.dtype, "m"):
32633270
self._handle.create_array(self.group, key, value.view("i8"))
32643271
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
3272+
elif isinstance(value, BaseStringArray):
3273+
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom())
3274+
vlarr.append(value.to_numpy())
3275+
node = getattr(self.group, key)
3276+
node._v_attrs.value_type = str(value.dtype)
32653277
elif empty_array:
32663278
self.write_array_empty(key, value)
32673279
else:
@@ -3294,7 +3306,11 @@ def read(
32943306
index = self.read_index("index", start=start, stop=stop)
32953307
values = self.read_array("values", start=start, stop=stop)
32963308
result = Series(values, index=index, name=self.name, copy=False)
3297-
if using_string_dtype() and is_string_array(values, skipna=True):
3309+
if (
3310+
using_string_dtype()
3311+
and isinstance(values, np.ndarray)
3312+
and is_string_array(values, skipna=True)
3313+
):
32983314
result = result.astype(StringDtype(na_value=np.nan))
32993315
return result
33003316

@@ -3363,7 +3379,11 @@ def read(
33633379

33643380
columns = items[items.get_indexer(blk_items)]
33653381
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
3366-
if using_string_dtype() and is_string_array(values, skipna=True):
3382+
if (
3383+
using_string_dtype()
3384+
and isinstance(values, np.ndarray)
3385+
and is_string_array(values, skipna=True)
3386+
):
33673387
df = df.astype(StringDtype(na_value=np.nan))
33683388
dfs.append(df)
33693389

@@ -4737,9 +4757,13 @@ def read(
47374757
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
47384758
if not (using_string_dtype() and values.dtype.kind == "O"):
47394759
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4740-
if using_string_dtype() and is_string_array(
4741-
values, # type: ignore[arg-type]
4742-
skipna=True,
4760+
if (
4761+
using_string_dtype()
4762+
and isinstance(values, np.ndarray)
4763+
and is_string_array(
4764+
values,
4765+
skipna=True,
4766+
)
47434767
):
47444768
df = df.astype(StringDtype(na_value=np.nan))
47454769
frames.append(df)

pandas/tests/io/pytables/test_put.py

+56-14
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas._libs.tslibs import Timestamp
97

108
import pandas as pd
@@ -26,7 +24,6 @@
2624

2725
pytestmark = [
2826
pytest.mark.single_cpu,
29-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
3027
]
3128

3229

@@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path):
5451
with ensure_clean_store(setup_path) as store:
5552
df = DataFrame(
5653
1.1 * np.arange(120).reshape((30, 4)),
57-
columns=Index(list("ABCD"), dtype=object),
58-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
54+
columns=Index(list("ABCD")),
55+
index=Index([f"i-{i}" for i in range(30)]),
5956
)
6057

6158
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path):
7976
path = tmp_path / setup_path
8077
df = DataFrame(
8178
1.1 * np.arange(120).reshape((30, 4)),
82-
columns=Index(list("ABCD"), dtype=object),
83-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
79+
columns=Index(list("ABCD")),
80+
index=Index([f"i-{i}" for i in range(30)]),
8481
)
8582

8683
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -106,7 +103,7 @@ def test_put(setup_path):
106103
)
107104
df = DataFrame(
108105
np.random.default_rng(2).standard_normal((20, 4)),
109-
columns=Index(list("ABCD"), dtype=object),
106+
columns=Index(list("ABCD")),
110107
index=date_range("2000-01-01", periods=20, freq="B"),
111108
)
112109
store["a"] = ts
@@ -166,7 +163,7 @@ def test_put_compression(setup_path):
166163
with ensure_clean_store(setup_path) as store:
167164
df = DataFrame(
168165
np.random.default_rng(2).standard_normal((10, 4)),
169-
columns=Index(list("ABCD"), dtype=object),
166+
columns=Index(list("ABCD")),
170167
index=date_range("2000-01-01", periods=10, freq="B"),
171168
)
172169

@@ -183,7 +180,7 @@ def test_put_compression(setup_path):
183180
def test_put_compression_blosc(setup_path):
184181
df = DataFrame(
185182
np.random.default_rng(2).standard_normal((10, 4)),
186-
columns=Index(list("ABCD"), dtype=object),
183+
columns=Index(list("ABCD")),
187184
index=date_range("2000-01-01", periods=10, freq="B"),
188185
)
189186

@@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path):
197194
tm.assert_frame_equal(store["c"], df)
198195

199196

200-
def test_put_mixed_type(setup_path, performance_warning):
197+
def test_put_datetime_ser(setup_path, performance_warning, using_infer_string):
198+
# https://github.com/pandas-dev/pandas/pull/60663
199+
ser = Series(3 * [Timestamp("20010102").as_unit("ns")])
200+
with ensure_clean_store(setup_path) as store:
201+
store.put("ser", ser)
202+
expected = ser.copy()
203+
result = store.get("ser")
204+
tm.assert_series_equal(result, expected)
205+
206+
207+
def test_put_mixed_type(setup_path, performance_warning, using_infer_string):
201208
df = DataFrame(
202209
np.random.default_rng(2).standard_normal((10, 4)),
203-
columns=Index(list("ABCD"), dtype=object),
210+
columns=Index(list("ABCD")),
204211
index=date_range("2000-01-01", periods=10, freq="B"),
205212
)
206213
df["obj1"] = "foo"
@@ -220,13 +227,42 @@ def test_put_mixed_type(setup_path, performance_warning):
220227
with ensure_clean_store(setup_path) as store:
221228
_maybe_remove(store, "df")
222229

223-
with tm.assert_produces_warning(performance_warning):
230+
warning = None if using_infer_string else performance_warning
231+
with tm.assert_produces_warning(warning):
224232
store.put("df", df)
225233

226234
expected = store.get("df")
227235
tm.assert_frame_equal(expected, df)
228236

229237

238+
def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments):
239+
# https://github.com/pandas-dev/pandas/pull/60663
240+
dtype = pd.StringDtype(*string_dtype_arguments)
241+
df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)})
242+
with ensure_clean_store(setup_path) as store:
243+
_maybe_remove(store, "df")
244+
245+
store.put("df", df)
246+
expected_dtype = "str" if dtype.na_value is np.nan else "string"
247+
expected = df.astype(expected_dtype)
248+
result = store.get("df")
249+
tm.assert_frame_equal(result, expected)
250+
251+
252+
def test_put_str_series(setup_path, performance_warning, string_dtype_arguments):
253+
# https://github.com/pandas-dev/pandas/pull/60663
254+
dtype = pd.StringDtype(*string_dtype_arguments)
255+
ser = Series(["x", pd.NA, "y"], dtype=dtype)
256+
with ensure_clean_store(setup_path) as store:
257+
_maybe_remove(store, "df")
258+
259+
store.put("ser", ser)
260+
expected_dtype = "str" if dtype.na_value is np.nan else "string"
261+
expected = ser.astype(expected_dtype)
262+
result = store.get("ser")
263+
tm.assert_series_equal(result, expected)
264+
265+
230266
@pytest.mark.parametrize("format", ["table", "fixed"])
231267
@pytest.mark.parametrize(
232268
"index",
@@ -253,7 +289,7 @@ def test_store_index_types(setup_path, format, index):
253289
tm.assert_frame_equal(df, store["df"])
254290

255291

256-
def test_column_multiindex(setup_path):
292+
def test_column_multiindex(setup_path, using_infer_string):
257293
# GH 4710
258294
# recreate multi-indexes properly
259295

@@ -264,6 +300,12 @@ def test_column_multiindex(setup_path):
264300
expected = df.set_axis(df.index.to_numpy())
265301

266302
with ensure_clean_store(setup_path) as store:
303+
if using_infer_string:
304+
# TODO(infer_string) make this work for string dtype
305+
msg = "Saving a MultiIndex with an extension dtype is not supported."
306+
with pytest.raises(NotImplementedError, match=msg):
307+
store.put("df", df)
308+
return
267309
store.put("df", df)
268310
tm.assert_frame_equal(
269311
store["df"], expected, check_index_type=True, check_column_type=True

0 commit comments

Comments
 (0)