diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index b75dc6c3a43b4..7e227d37c74d6 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -86,6 +86,7 @@ PeriodArray, ) from pandas.core.arrays.datetimes import tz_to_dtype +from pandas.core.arrays.string_ import BaseStringArray import pandas.core.common as com from pandas.core.computation.pytables import ( PyTablesExpr, @@ -3185,6 +3186,8 @@ def write_array( # both self._filters and EA value = extract_array(obj, extract_numpy=True) + if isinstance(value, BaseStringArray): + value = value.to_numpy() if key in self.group: self._handle.remove_node(self.group, key) @@ -3294,7 +3297,11 @@ def read( index = self.read_index("index", start=start, stop=stop) values = self.read_array("values", start=start, stop=stop) result = Series(values, index=index, name=self.name, copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): result = result.astype(StringDtype(na_value=np.nan)) return result @@ -3363,7 +3370,11 @@ def read( columns = items[items.get_indexer(blk_items)] df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): df = df.astype(StringDtype(na_value=np.nan)) dfs.append(df) @@ -4737,9 +4748,10 @@ def read( df = DataFrame._from_arrays([values], columns=cols_, index=index_) if not (using_string_dtype() and values.dtype.kind == "O"): assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) - if using_string_dtype() and is_string_array( - values, # type: ignore[arg-type] - skipna=True, + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) ): df = df.astype(StringDtype(na_value=np.nan)) frames.append(df) diff --git a/pandas/tests/io/pytables/test_put.py b/pandas/tests/io/pytables/test_put.py index a4257b54dd6db..9bfb10699c8b5 100644 --- a/pandas/tests/io/pytables/test_put.py +++ b/pandas/tests/io/pytables/test_put.py @@ -3,8 +3,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp import pandas as pd @@ -26,7 +24,6 @@ pytestmark = [ pytest.mark.single_cpu, - pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False), ] @@ -106,7 +103,7 @@ def test_put(setup_path): ) df = DataFrame( np.random.default_rng(2).standard_normal((20, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=20, freq="B"), ) store["a"] = ts @@ -133,7 +130,9 @@ def test_put(setup_path): # overwrite table store.put("c", df[:10], format="table", append=False) - tm.assert_frame_equal(df[:10], store["c"]) + expected = df[:10] + result = store["c"] + tm.assert_frame_equal(result, expected) def test_put_string_index(setup_path): @@ -166,12 +165,14 @@ def test_put_compression(setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) store.put("c", df, format="table", complib="zlib") - tm.assert_frame_equal(store["c"], df) + expected = df + result = store["c"] + tm.assert_frame_equal(result, expected) # can't compress if format='fixed' msg = "Compression not supported on Fixed format stores" @@ -183,7 +184,7 @@ def test_put_compression(setup_path): def test_put_compression_blosc(setup_path): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) @@ -194,17 +195,29 @@ def test_put_compression_blosc(setup_path): store.put("b", df, format="fixed", complib="blosc") store.put("c", df, format="table", complib="blosc") - tm.assert_frame_equal(store["c"], df) + expected = df + result = store["c"] + tm.assert_frame_equal(result, expected) -def test_put_mixed_type(setup_path, performance_warning): +def test_put_datetime_ser(setup_path, performance_warning, using_infer_string): + # https://github.com/pandas-dev/pandas/pull/60625 + ser = Series(3 * [Timestamp("20010102").as_unit("ns")]) + with ensure_clean_store(setup_path) as store: + store.put("ser", ser) + expected = ser.copy() + result = store.get("ser") + tm.assert_series_equal(result, expected) + + +def test_put_mixed_type(setup_path, performance_warning, using_infer_string): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) df["obj1"] = "foo" - df["obj2"] = "bar" + df["obj2"] = pd.array([pd.NA] + 9 * ["bar"], dtype="string") df["bool1"] = df["A"] > 0 df["bool2"] = df["B"] > 0 df["bool3"] = True @@ -223,8 +236,13 @@ def test_put_mixed_type(setup_path, performance_warning): with tm.assert_produces_warning(performance_warning): store.put("df", df) - expected = store.get("df") - tm.assert_frame_equal(expected, df) + expected = df.copy() + if using_infer_string: + expected["obj2"] = expected["obj2"].astype("str") + else: + expected["obj2"] = expected["obj2"].astype("object") + result = store.get("df") + tm.assert_frame_equal(result, expected) @pytest.mark.parametrize("format", ["table", "fixed"]) @@ -253,7 +271,7 @@ def test_store_index_types(setup_path, format, index): tm.assert_frame_equal(df, store["df"]) -def test_column_multiindex(setup_path): +def test_column_multiindex(setup_path, using_infer_string): # GH 4710 # recreate multi-indexes properly @@ -264,6 +282,12 @@ def test_column_multiindex(setup_path): expected = df.set_axis(df.index.to_numpy()) with ensure_clean_store(setup_path) as store: + if using_infer_string: + # TODO(infer_string) make this work for string dtype + msg = "Saving a MultiIndex with an extension dtype is not supported." + with pytest.raises(NotImplementedError, match=msg): + store.put("df", df) + return store.put("df", df) tm.assert_frame_equal( store["df"], expected, check_index_type=True, check_column_type=True