3
3
import numpy as np
4
4
import pytest
5
5
6
- from pandas ._config import using_string_dtype
7
-
8
6
from pandas ._libs .tslibs import Timestamp
9
7
10
8
import pandas as pd
26
24
27
25
pytestmark = [
28
26
pytest .mark .single_cpu ,
29
- pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False ),
30
27
]
31
28
32
29
@@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path):
54
51
with ensure_clean_store (setup_path ) as store :
55
52
df = DataFrame (
56
53
1.1 * np .arange (120 ).reshape ((30 , 4 )),
57
- columns = Index (list ("ABCD" ), dtype = object ),
58
- index = Index ([f"i-{ i } " for i in range (30 )], dtype = object ),
54
+ columns = Index (list ("ABCD" )),
55
+ index = Index ([f"i-{ i } " for i in range (30 )]),
59
56
)
60
57
61
58
with pd .option_context ("io.hdf.default_format" , "fixed" ):
@@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path):
79
76
path = tmp_path / setup_path
80
77
df = DataFrame (
81
78
1.1 * np .arange (120 ).reshape ((30 , 4 )),
82
- columns = Index (list ("ABCD" ), dtype = object ),
83
- index = Index ([f"i-{ i } " for i in range (30 )], dtype = object ),
79
+ columns = Index (list ("ABCD" )),
80
+ index = Index ([f"i-{ i } " for i in range (30 )]),
84
81
)
85
82
86
83
with pd .option_context ("io.hdf.default_format" , "fixed" ):
@@ -106,7 +103,7 @@ def test_put(setup_path):
106
103
)
107
104
df = DataFrame (
108
105
np .random .default_rng (2 ).standard_normal ((20 , 4 )),
109
- columns = Index (list ("ABCD" ), dtype = object ),
106
+ columns = Index (list ("ABCD" )),
110
107
index = date_range ("2000-01-01" , periods = 20 , freq = "B" ),
111
108
)
112
109
store ["a" ] = ts
@@ -166,7 +163,7 @@ def test_put_compression(setup_path):
166
163
with ensure_clean_store (setup_path ) as store :
167
164
df = DataFrame (
168
165
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
169
- columns = Index (list ("ABCD" ), dtype = object ),
166
+ columns = Index (list ("ABCD" )),
170
167
index = date_range ("2000-01-01" , periods = 10 , freq = "B" ),
171
168
)
172
169
@@ -183,7 +180,7 @@ def test_put_compression(setup_path):
183
180
def test_put_compression_blosc (setup_path ):
184
181
df = DataFrame (
185
182
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
186
- columns = Index (list ("ABCD" ), dtype = object ),
183
+ columns = Index (list ("ABCD" )),
187
184
index = date_range ("2000-01-01" , periods = 10 , freq = "B" ),
188
185
)
189
186
@@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path):
197
194
tm .assert_frame_equal (store ["c" ], df )
198
195
199
196
200
- def test_put_mixed_type (setup_path , performance_warning ):
197
def test_put_datetime_ser(setup_path, performance_warning, using_infer_string):
    """Round-trip a Series of three identical ns-unit Timestamps via put/get."""
    # https://github.com/pandas-dev/pandas/pull/60663
    stamp = Timestamp("20010102").as_unit("ns")
    ser = Series([stamp] * 3)
    with ensure_clean_store(setup_path) as store:
        store.put("ser", ser)
        # Compare against a copy taken after the write, then read the key back.
        expected = ser.copy()
        roundtripped = store.get("ser")
        tm.assert_series_equal(roundtripped, expected)
205
+
206
+
207
+ def test_put_mixed_type (setup_path , performance_warning , using_infer_string ):
201
208
df = DataFrame (
202
209
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
203
- columns = Index (list ("ABCD" ), dtype = object ),
210
+ columns = Index (list ("ABCD" )),
204
211
index = date_range ("2000-01-01" , periods = 10 , freq = "B" ),
205
212
)
206
213
df ["obj1" ] = "foo"
@@ -220,13 +227,42 @@ def test_put_mixed_type(setup_path, performance_warning):
220
227
with ensure_clean_store (setup_path ) as store :
221
228
_maybe_remove (store , "df" )
222
229
223
- with tm .assert_produces_warning (performance_warning ):
230
+ warning = None if using_infer_string else performance_warning
231
+ with tm .assert_produces_warning (warning ):
224
232
store .put ("df" , df )
225
233
226
234
expected = store .get ("df" )
227
235
tm .assert_frame_equal (expected , df )
228
236
229
237
238
def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments):
    """Round-trip a one-column string DataFrame through put/get.

    The frame read back is compared against the input cast to ``"str"`` when
    the string dtype's ``na_value`` is ``np.nan``, and to ``"string"`` otherwise.
    """
    # https://github.com/pandas-dev/pandas/pull/60663
    dtype = pd.StringDtype(*string_dtype_arguments)
    values = pd.array(["x", pd.NA, "y"], dtype=dtype)
    df = DataFrame({"a": values})

    # The dtype expected on read depends on the missing-value sentinel.
    if dtype.na_value is np.nan:
        expected_dtype = "str"
    else:
        expected_dtype = "string"

    with ensure_clean_store(setup_path) as store:
        _maybe_remove(store, "df")
        store.put("df", df)

        expected = df.astype(expected_dtype)
        result = store.get("df")
        tm.assert_frame_equal(result, expected)
250
+
251
+
252
def test_put_str_series(setup_path, performance_warning, string_dtype_arguments):
    """Round-trip a string-dtype Series through put/get.

    The Series read back is compared against the input cast to ``"str"`` when
    the string dtype's ``na_value`` is ``np.nan``, and to ``"string"`` otherwise.
    """
    # https://github.com/pandas-dev/pandas/pull/60663
    dtype = pd.StringDtype(*string_dtype_arguments)
    ser = Series(["x", pd.NA, "y"], dtype=dtype)
    with ensure_clean_store(setup_path) as store:
        # Clear the key this test actually writes ("ser"); the previous code
        # removed "df", a leftover from the sibling DataFrame test.
        _maybe_remove(store, "ser")

        store.put("ser", ser)
        expected_dtype = "str" if dtype.na_value is np.nan else "string"
        expected = ser.astype(expected_dtype)
        result = store.get("ser")
        tm.assert_series_equal(result, expected)
264
+
265
+
230
266
@pytest .mark .parametrize ("format" , ["table" , "fixed" ])
231
267
@pytest .mark .parametrize (
232
268
"index" ,
@@ -253,7 +289,7 @@ def test_store_index_types(setup_path, format, index):
253
289
tm .assert_frame_equal (df , store ["df" ])
254
290
255
291
256
- def test_column_multiindex (setup_path ):
292
+ def test_column_multiindex (setup_path , using_infer_string ):
257
293
# GH 4710
258
294
# recreate multi-indexes properly
259
295
@@ -264,6 +300,12 @@ def test_column_multiindex(setup_path):
264
300
expected = df .set_axis (df .index .to_numpy ())
265
301
266
302
with ensure_clean_store (setup_path ) as store :
303
+ if using_infer_string :
304
+ # TODO(infer_string) make this work for string dtype
305
+ msg = "Saving a MultiIndex with an extension dtype is not supported."
306
+ with pytest .raises (NotImplementedError , match = msg ):
307
+ store .put ("df" , df )
308
+ return
267
309
store .put ("df" , df )
268
310
tm .assert_frame_equal (
269
311
store ["df" ], expected , check_index_type = True , check_column_type = True
0 commit comments