Skip to content

Commit ae12843

Browse files
committed
use Index[Any] for __eq__ and __neq__
1 parent f39c5c4 commit ae12843

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pandas-stubs/_libs/tslibs/timestamps.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ from datetime import (
99
import sys
1010
from time import struct_time
1111
from typing import (
12+
Any,
1213
ClassVar,
1314
Literal,
1415
SupportsIndex,
@@ -19,6 +20,7 @@ from _typing import TimeZones
1920
import numpy as np
2021
from pandas import (
2122
DatetimeIndex,
23+
Index,
2224
TimedeltaIndex,
2325
)
2426
from pandas.core.series import (
@@ -231,15 +233,15 @@ class Timestamp(datetime, SupportsIndex):
231233
@overload
232234
def __eq__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
233235
@overload
234-
def __eq__(self, other: npt.NDArray[np.datetime64] | DatetimeIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
236+
def __eq__(self, other: npt.NDArray[np.datetime64] | Index[Any]) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
235237
@overload
236238
def __eq__(self, other: object) -> Literal[False]: ...
237239
@overload
238240
def __ne__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
239241
@overload
240242
def __ne__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
241243
@overload
242-
def __ne__(self, other: npt.NDArray[np.datetime64] | DatetimeIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
244+
def __ne__(self, other: npt.NDArray[np.datetime64] | Index[Any]) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
243245
@overload
244246
def __ne__(self, other: object) -> Literal[True]: ...
245247
def __hash__(self) -> int: ...

tests/test_scalars.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,8 @@ def test_timestamp_cmp() -> None:
13251325

13261326
check(assert_type(ts >= c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_)
13271327
check(assert_type(ts < c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_)
1328+
check(assert_type(ts >= c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_)
1329+
check(assert_type(ts < c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_)
13281330

13291331
check(assert_type(ts >= c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_)
13301332
check(assert_type(ts < c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_)
@@ -1370,6 +1372,13 @@ def test_timestamp_cmp() -> None:
13701372
assert_type(ts != c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_
13711373
)
13721374
assert (eq_arr != ne_arr).all()
1375+
eq_arr = check(
1376+
assert_type(ts == c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_
1377+
)
1378+
ne_arr = check(
1379+
assert_type(ts != c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_
1380+
)
1381+
assert (eq_arr != ne_arr).all()
13731382

13741383
eq_arr = check(
13751384
assert_type(ts == c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_

0 commit comments

Comments
 (0)