Skip to content

Commit ff28a3e

Browse files
authored
API (string): return str dtype for .dt methods, DatetimeIndex methods (#59526)
* API (string): return str dtype for .dt methods, DatetimeIndex methods * mypy fixup
1 parent 96a7462 commit ff28a3e

File tree

7 files changed

+43
-17
lines changed

7 files changed

+43
-17
lines changed

pandas/core/arrays/datetimelike.py

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import numpy as np
2121

22+
from pandas._config import using_string_dtype
2223
from pandas._config.config import get_option
2324

2425
from pandas._libs import (
@@ -1759,6 +1760,10 @@ def strftime(self, date_format: str) -> npt.NDArray[np.object_]:
17591760
dtype='object')
17601761
"""
17611762
result = self._format_native_types(date_format=date_format, na_rep=np.nan)
1763+
if using_string_dtype():
1764+
from pandas import StringDtype
1765+
1766+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
17621767
return result.astype(object, copy=False)
17631768

17641769

pandas/core/arrays/datetimes.py

+16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717

18+
from pandas._config import using_string_dtype
1819
from pandas._config.config import get_option
1920

2021
from pandas._libs import (
@@ -1332,6 +1333,13 @@ def month_name(self, locale=None) -> npt.NDArray[np.object_]:
13321333
values, "month_name", locale=locale, reso=self._creso
13331334
)
13341335
result = self._maybe_mask_results(result, fill_value=None)
1336+
if using_string_dtype():
1337+
from pandas import (
1338+
StringDtype,
1339+
array as pd_array,
1340+
)
1341+
1342+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
13351343
return result
13361344

13371345
def day_name(self, locale=None) -> npt.NDArray[np.object_]:
@@ -1393,6 +1401,14 @@ def day_name(self, locale=None) -> npt.NDArray[np.object_]:
13931401
values, "day_name", locale=locale, reso=self._creso
13941402
)
13951403
result = self._maybe_mask_results(result, fill_value=None)
1404+
if using_string_dtype():
1405+
# TODO: no tests that check for dtype of result as of 2024-08-15
1406+
from pandas import (
1407+
StringDtype,
1408+
array as pd_array,
1409+
)
1410+
1411+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
13961412
return result
13971413

13981414
@property

pandas/core/indexes/datetimes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _engine_type(self) -> type[libindex.DatetimeEngine]:
263263
@doc(DatetimeArray.strftime)
264264
def strftime(self, date_format) -> Index:
265265
arr = self._data.strftime(date_format)
266-
return Index(arr, name=self.name, dtype=object)
266+
return Index(arr, name=self.name, dtype=arr.dtype)
267267

268268
@doc(DatetimeArray.tz_convert)
269269
def tz_convert(self, tz) -> Self:

pandas/core/indexes/extension.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def fget(self):
7474
return type(self)._simple_new(result, name=self.name)
7575
elif isinstance(result, ABCDataFrame):
7676
return result.set_index(self)
77-
return Index(result, name=self.name)
77+
return Index(result, name=self.name, dtype=result.dtype)
7878
return result
7979

8080
def fset(self, value) -> None:
@@ -101,7 +101,7 @@ def method(self, *args, **kwargs): # type: ignore[misc]
101101
return type(self)._simple_new(result, name=self.name)
102102
elif isinstance(result, ABCDataFrame):
103103
return result.set_index(self)
104-
return Index(result, name=self.name)
104+
return Index(result, name=self.name, dtype=result.dtype)
105105
return result
106106

107107
# error: "property" has no attribute "__name__"

pandas/tests/arrays/test_datetimelike.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -891,20 +891,24 @@ def test_concat_same_type_different_freq(self, unit):
891891

892892
tm.assert_datetime_array_equal(result, expected)
893893

894-
def test_strftime(self, arr1d):
894+
def test_strftime(self, arr1d, using_infer_string):
895895
arr = arr1d
896896

897897
result = arr.strftime("%Y %b")
898898
expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object)
899-
tm.assert_numpy_array_equal(result, expected)
899+
if using_infer_string:
900+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
901+
tm.assert_equal(result, expected)
900902

901-
def test_strftime_nat(self):
903+
def test_strftime_nat(self, using_infer_string):
902904
# GH 29578
903905
arr = DatetimeIndex(["2019-01-01", NaT])._data
904906

905907
result = arr.strftime("%Y-%m-%d")
906908
expected = np.array(["2019-01-01", np.nan], dtype=object)
907-
tm.assert_numpy_array_equal(result, expected)
909+
if using_infer_string:
910+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
911+
tm.assert_equal(result, expected)
908912

909913

910914
class TestTimedeltaArray(SharedTests):
@@ -1161,20 +1165,24 @@ def test_array_interface(self, arr1d):
11611165
expected = np.asarray(arr).astype("S20")
11621166
tm.assert_numpy_array_equal(result, expected)
11631167

1164-
def test_strftime(self, arr1d):
1168+
def test_strftime(self, arr1d, using_infer_string):
11651169
arr = arr1d
11661170

11671171
result = arr.strftime("%Y")
11681172
expected = np.array([per.strftime("%Y") for per in arr], dtype=object)
1169-
tm.assert_numpy_array_equal(result, expected)
1173+
if using_infer_string:
1174+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
1175+
tm.assert_equal(result, expected)
11701176

1171-
def test_strftime_nat(self):
1177+
def test_strftime_nat(self, using_infer_string):
11721178
# GH 29578
11731179
arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]"))
11741180

11751181
result = arr.strftime("%Y-%m-%d")
11761182
expected = np.array(["2019-01-01", np.nan], dtype=object)
1177-
tm.assert_numpy_array_equal(result, expected)
1183+
if using_infer_string:
1184+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
1185+
tm.assert_equal(result, expected)
11781186

11791187

11801188
@pytest.mark.parametrize(

pandas/tests/io/excel/test_writers.py

-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def test_excel_multindex_roundtrip(
282282
)
283283
tm.assert_frame_equal(df, act, check_names=check_names)
284284

285-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
286285
def test_read_excel_parse_dates(self, tmp_excel):
287286
# see gh-11544, gh-12051
288287
df = DataFrame(

pandas/tests/series/accessors/test_dt_accessor.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Period,
2828
PeriodIndex,
2929
Series,
30+
StringDtype,
3031
TimedeltaIndex,
3132
date_range,
3233
period_range,
@@ -513,7 +514,6 @@ def test_dt_accessor_datetime_name_accessors(self, time_locale):
513514
ser = pd.concat([ser, Series([pd.NaT])])
514515
assert np.isnan(ser.dt.month_name(locale=time_locale).iloc[-1])
515516

516-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
517517
def test_strftime(self):
518518
# GH 10086
519519
ser = Series(date_range("20130101", periods=5))
@@ -584,10 +584,9 @@ def test_strftime_period_days(self, using_infer_string):
584584
dtype="=U10",
585585
)
586586
if using_infer_string:
587-
expected = expected.astype("str")
587+
expected = expected.astype(StringDtype(na_value=np.nan))
588588
tm.assert_index_equal(result, expected)
589589

590-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
591590
def test_strftime_dt64_microsecond_resolution(self):
592591
ser = Series([datetime(2013, 1, 1, 2, 32, 59), datetime(2013, 1, 2, 14, 32, 1)])
593592
result = ser.dt.strftime("%Y-%m-%d %H:%M:%S")
@@ -620,7 +619,6 @@ def test_strftime_period_minutes(self):
620619
)
621620
tm.assert_series_equal(result, expected)
622621

623-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
624622
@pytest.mark.parametrize(
625623
"data",
626624
[
@@ -643,7 +641,7 @@ def test_strftime_all_nat(self, data):
643641
ser = Series(data)
644642
with tm.assert_produces_warning(None):
645643
result = ser.dt.strftime("%Y-%m-%d")
646-
expected = Series([np.nan], dtype=object)
644+
expected = Series([np.nan], dtype="str")
647645
tm.assert_series_equal(result, expected)
648646

649647
def test_valid_dt_with_missing_values(self):

0 commit comments

Comments
 (0)